OpenDAS / torchani

Commit e1c78efe
Authored Jun 06, 2019 by Farhad Ramezanghorbani; committed by Gao, Xiang, Jun 06, 2019
Updated nnp_training tutorial to be as close to NeuroChem as possible (#245)
Parent: dc8930ee
Changes: 1 changed file with 44 additions and 61 deletions

examples/nnp_training.py  (+44, -61)
@@ -156,6 +156,27 @@ O_network = torch.nn.Sequential(
 nn = torchani.ANIModel([H_network, C_network, N_network, O_network])
 print(nn)
 
+###############################################################################
+# Initialize the weights and biases.
+#
+# .. note::
+#   Pytorch default initialization for the weights and biases in linear layers
+#   is Kaiming uniform. See: `TORCH.NN.MODULES.LINEAR`_
+#   We initialize the weights similarly but from the normal distribution.
+#   The biases were initialized to zero.
+#
+# .. _TORCH.NN.MODULES.LINEAR:
+#   https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
+def init_params(m):
+    if isinstance(m, torch.nn.Linear):
+        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
+        torch.nn.init.zeros_(m.bias)
+
+
+nn.apply(init_params)
+
 ###############################################################################
 # Let's now create a pipeline of AEV Computer --> Neural Networks.
 model = torch.nn.Sequential(aev_computer, nn).to(device)
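For readers trying the new initialization outside the tutorial, here is a minimal, self-contained sketch (not part of the commit; the toy network below is a hypothetical stand-in for H_network and friends). It shows what nn.apply(init_params) does: Module.apply visits every submodule, so each torch.nn.Linear ends up with Kaiming-normal weights and zeroed biases.

import torch

def init_params(m):
    # Same rule the commit adds: Kaiming-normal weights, zero biases,
    # applied only to linear layers.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.kaiming_normal_(m.weight, a=1.0)
        torch.nn.init.zeros_(m.bias)

toy = torch.nn.Sequential(      # hypothetical stand-in for an ANI sub-network
    torch.nn.Linear(8, 16),
    torch.nn.CELU(0.1),
    torch.nn.Linear(16, 1),
)
toy.apply(init_params)
print(toy[0].bias.abs().max().item())  # prints 0.0: biases were zeroed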
@@ -218,19 +239,25 @@ optimizer = torchani.optim.AdamW([
 ])
 
 ###############################################################################
-# The way ANI trains a neural network potential looks like this:
-#
-# Phase 1: Pretrain the model by minimizing MSE loss
-#
-# Phase 2: Train the model by minimizing the exponential loss, until validation
-# RMSE no longer improves for a certain steps, decay the learning rate and repeat
-# the same process, stop until the learning rate is smaller than a certain number.
+# Setting up a learning rate scheduler to do learning rate decay
+scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100, threshold=0)
+
+###############################################################################
+# Train the model by minimizing the MSE loss, until validation RMSE no longer
+# improves during a certain number of steps, decay the learning rate and repeat
+# the same process, stop until the learning rate is smaller than a threshold.
 #
-# We first read the checkpoint files to find where we are. We use `latest.pt`
-# to store current training state. If `latest.pt` does not exist, this
-# means the pretraining has not been finished yet.
+# We first read the checkpoint files to restart training. We use `latest.pt`
+# to store current training state.
 latest_checkpoint = 'latest.pt'
-pretrained = os.path.isfile(latest_checkpoint)
+
+###############################################################################
+# Resume training from previously saved checkpoints:
+if os.path.isfile(latest_checkpoint):
+    checkpoint = torch.load(latest_checkpoint)
+    model.load_state_dict(checkpoint['nn'])
+    optimizer.load_state_dict(checkpoint['optimizer'])
+    scheduler.load_state_dict(checkpoint['scheduler'])
 
 ###############################################################################
 # During training, we need to validate on validation set and if validation error
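As a sanity check of the new resume logic, the following is a minimal sketch (not from the repository) that saves and reloads a checkpoint with the same 'nn', 'optimizer', and 'scheduler' keys the added code expects. The tiny linear model is a hypothetical stand-in for the tutorial's network, and it assumes a PyTorch version whose ReduceLROnPlateau exposes state_dict()/load_state_dict().

import os
import torch

latest_checkpoint = 'latest.pt'   # same filename as in the tutorial

model = torch.nn.Linear(4, 1)     # hypothetical stand-in for the ANI model
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=100, threshold=0)

# Save a checkpoint with the keys the resume block reads.
torch.save({
    'nn': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'scheduler': scheduler.state_dict(),
}, latest_checkpoint)

# Resuming mirrors the code added in the hunk above.
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    model.load_state_dict(checkpoint['nn'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    scheduler.load_state_dict(checkpoint['scheduler'])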
@@ -259,61 +286,18 @@ def validate():
     return hartree2kcal(math.sqrt(total_mse / count))
 
-###############################################################################
-# If the model is not pretrained yet, we need to run the pretrain.
-pretrain_criterion = 10  # kcal/mol
-mse = torch.nn.MSELoss(reduction='none')
-if not pretrained:
-    print("pre-training...")
-    epoch = 0
-    rmse = math.inf
-    pretrain_optimizer = torch.optim.Adam(nn.parameters())
-    while rmse > pretrain_criterion:
-        for batch_x, batch_y in tqdm.tqdm(training):
-            true_energies = batch_y['energies']
-            predicted_energies = []
-            num_atoms = []
-            for chunk_species, chunk_coordinates in batch_x:
-                num_atoms.append((chunk_species >= 0).sum(dim=1))
-                _, chunk_energies = model((chunk_species, chunk_coordinates))
-                predicted_energies.append(chunk_energies)
-            num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
-            predicted_energies = torch.cat(predicted_energies)
-            loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
-            pretrain_optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-        rmse = validate()
-        print('RMSE:', rmse, 'Target RMSE:', pretrain_criterion)
-    torch.save({
-        'nn': nn.state_dict(),
-        'optimizer': optimizer.state_dict(),
-    }, latest_checkpoint)
-
-###############################################################################
-# For phase 2, we need a learning rate scheduler to do learning rate decay
-scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=100)
-
 ###############################################################################
 # We will also use TensorBoard to visualize our training process
 tensorboard = torch.utils.tensorboard.SummaryWriter()
 
-###############################################################################
-# Resume training from previously saved checkpoints:
-checkpoint = torch.load(latest_checkpoint)
-nn.load_state_dict(checkpoint['nn'])
-optimizer.load_state_dict(checkpoint['optimizer'])
-if 'scheduler' in checkpoint:
-    scheduler.load_state_dict(checkpoint['scheduler'])
-
 ###############################################################################
 # Finally, we come to the training loop.
 #
 # In this tutorial, we are setting the maximum epoch to a very small number,
 # only to make this demo terminate fast. For serious training, this should be
 # set to a much larger value
+mse = torch.nn.MSELoss(reduction='none')
 print("training starting from epoch", scheduler.last_epoch + 1)
 max_epochs = 200
 early_stopping_learning_rate = 1.0E-5
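After this hunk, the single training objective is the per-atom MSE, and validate() reports an RMSE converted to kcal/mol. The sketch below (not from the repository; the Hartree-to-kcal/mol factor and the dummy energies are assumptions for illustration) shows how both quantities are computed.

import math
import torch

HARTREE2KCALMOL = 627.509  # assumed conversion factor; torchani provides its own helper

mse = torch.nn.MSELoss(reduction='none')

true_energies = torch.tensor([-40.5, -76.4])       # dummy values, Hartree
predicted_energies = torch.tensor([-40.4, -76.5])  # dummy values, Hartree
num_atoms = torch.tensor([5.0, 3.0])

# Per-molecule squared error normalized by atom count, then averaged --
# the same expression used in the training loop.
loss = (mse(predicted_energies, true_energies) / num_atoms).mean()

# validate() reports an RMSE converted from Hartree to kcal/mol.
rmse_kcalmol = HARTREE2KCALMOL * math.sqrt(
    mse(predicted_energies, true_energies).mean().item())

print(loss.item(), rmse_kcalmol)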
@@ -321,16 +305,16 @@ best_model_checkpoint = 'best.pt'
 for _ in range(scheduler.last_epoch + 1, max_epochs):
     rmse = validate()
-    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch)
+    print('RMSE:', rmse, 'at epoch', scheduler.last_epoch + 1)
 
     learning_rate = optimizer.param_groups[0]['lr']
 
     if learning_rate < early_stopping_learning_rate:
         break
 
-    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch)
-    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch)
-    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch)
+    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
+    tensorboard.add_scalar('best_validation_rmse', scheduler.best, scheduler.last_epoch + 1)
+    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
 
     # checkpoint
     if scheduler.is_better(rmse, scheduler.best):
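This hunk only shifts the reported epoch index to scheduler.last_epoch + 1. A compressed sketch of that bookkeeping pattern follows (not from the repository; the constant RMSE is a placeholder for validate(), the parameters are hypothetical, and it assumes the tensorboard package is installed).

import torch
import torch.utils.tensorboard

params = [torch.nn.Parameter(torch.zeros(1))]        # hypothetical parameters
optimizer = torch.optim.Adam(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.5, patience=100, threshold=0)
tensorboard = torch.utils.tensorboard.SummaryWriter()
early_stopping_learning_rate = 1.0E-5
max_epochs = 5

for _ in range(scheduler.last_epoch + 1, max_epochs):
    rmse = 1.0                                       # placeholder for validate()
    learning_rate = optimizer.param_groups[0]['lr']
    if learning_rate < early_stopping_learning_rate:
        break
    # Epochs are logged 1-based, hence last_epoch + 1.
    tensorboard.add_scalar('validation_rmse', rmse, scheduler.last_epoch + 1)
    tensorboard.add_scalar('learning_rate', learning_rate, scheduler.last_epoch + 1)
    scheduler.step(rmse)                             # advances last_epoch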
@@ -349,7 +333,6 @@ for _ in range(scheduler.last_epoch + 1, max_epochs):
         num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
         predicted_energies = torch.cat(predicted_energies)
         loss = (mse(predicted_energies, true_energies) / num_atoms).mean()
-        loss = 0.5 * (torch.exp(2 * loss) - 1)
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
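The only change in this last hunk is dropping the exponential transform of the loss, so the model now trains directly on the per-atom MSE. A short sketch (the numeric value is arbitrary) comparing the removed objective with the plain MSE:

import torch

mse_loss = torch.tensor(0.05)                    # arbitrary per-atom MSE value
exp_loss = 0.5 * (torch.exp(2 * mse_loss) - 1)   # the transform this commit removes
print(mse_loss.item(), exp_loss.item())          # the exponential form penalizes large errors more steeply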