Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
4dcd6ab0
Unverified
Commit
4dcd6ab0
authored
May 20, 2019
by
Gao, Xiang
Committed by
GitHub
May 20, 2019
Browse files
Use a seperate optimizer to pretrain, and pretrain more (#226)
parent
2ad4126f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
29 deletions
+34
-29
examples/nnp_training.py
examples/nnp_training.py
+34
-29
No files found.
examples/nnp_training.py
View file @
4dcd6ab0
...
@@ -234,15 +234,44 @@ optimizer = torchani.optim.AdamW([
...
@@ -234,15 +234,44 @@ optimizer = torchani.optim.AdamW([
latest_checkpoint
=
'latest.pt'
latest_checkpoint
=
'latest.pt'
pretrained
=
os
.
path
.
isfile
(
latest_checkpoint
)
pretrained
=
os
.
path
.
isfile
(
latest_checkpoint
)
###############################################################################
# During training, we need to validate on validation set and if validation error
# is better than the best, then save the new best model to a checkpoint
# helper function to convert energy unit from Hartree to kcal/mol
def
hartree2kcal
(
x
):
return
627.509
*
x
def
validate
():
# run validation
mse_sum
=
torch
.
nn
.
MSELoss
(
reduction
=
'sum'
)
total_mse
=
0.0
count
=
0
for
batch_x
,
batch_y
in
validation
:
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
count
+=
predicted_energies
.
shape
[
0
]
return
hartree2kcal
(
math
.
sqrt
(
total_mse
/
count
))
###############################################################################
###############################################################################
# If the model is not pretrained yet, we need to run the pretrain.
# If the model is not pretrained yet, we need to run the pretrain.
pretrain_
epoches
=
10
pretrain_
criterion
=
10
# kcal/mol
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
mse
=
torch
.
nn
.
MSELoss
(
reduction
=
'none'
)
if
not
pretrained
:
if
not
pretrained
:
print
(
"pre-training..."
)
print
(
"pre-training..."
)
epoch
=
0
epoch
=
0
for
_
in
range
(
pretrain_epoches
):
rmse
=
math
.
inf
pretrain_optimizer
=
torch
.
optim
.
Adam
(
nn
.
parameters
())
while
rmse
>
pretrain_criterion
:
for
batch_x
,
batch_y
in
tqdm
.
tqdm
(
training
):
for
batch_x
,
batch_y
in
tqdm
.
tqdm
(
training
):
true_energies
=
batch_y
[
'energies'
]
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
predicted_energies
=
[]
...
@@ -254,9 +283,11 @@ if not pretrained:
...
@@ -254,9 +283,11 @@ if not pretrained:
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
num_atoms
=
torch
.
cat
(
num_atoms
).
to
(
true_energies
.
dtype
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
loss
=
(
mse
(
predicted_energies
,
true_energies
)
/
num_atoms
).
mean
()
optimizer
.
zero_grad
()
pretrain_
optimizer
.
zero_grad
()
loss
.
backward
()
loss
.
backward
()
optimizer
.
step
()
optimizer
.
step
()
rmse
=
validate
()
print
(
'RMSE:'
,
rmse
,
'Target RMSE:'
,
pretrain_criterion
)
torch
.
save
({
torch
.
save
({
'nn'
:
nn
.
state_dict
(),
'nn'
:
nn
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
'optimizer'
:
optimizer
.
state_dict
(),
...
@@ -278,32 +309,6 @@ optimizer.load_state_dict(checkpoint['optimizer'])
...
@@ -278,32 +309,6 @@ optimizer.load_state_dict(checkpoint['optimizer'])
if
'scheduler'
in
checkpoint
:
if
'scheduler'
in
checkpoint
:
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
scheduler
.
load_state_dict
(
checkpoint
[
'scheduler'
])
###############################################################################
# During training, we need to validate on validation set and if validation error
# is better than the best, then save the new best model to a checkpoint
# helper function to convert energy unit from Hartree to kcal/mol
def
hartree2kcal
(
x
):
return
627.509
*
x
def
validate
():
# run validation
mse_sum
=
torch
.
nn
.
MSELoss
(
reduction
=
'sum'
)
total_mse
=
0.0
count
=
0
for
batch_x
,
batch_y
in
validation
:
true_energies
=
batch_y
[
'energies'
]
predicted_energies
=
[]
for
chunk_species
,
chunk_coordinates
in
batch_x
:
_
,
chunk_energies
=
model
((
chunk_species
,
chunk_coordinates
))
predicted_energies
.
append
(
chunk_energies
)
predicted_energies
=
torch
.
cat
(
predicted_energies
)
total_mse
+=
mse_sum
(
predicted_energies
,
true_energies
).
item
()
count
+=
predicted_energies
.
shape
[
0
]
return
hartree2kcal
(
math
.
sqrt
(
total_mse
/
count
))
###############################################################################
###############################################################################
# Finally, we come to the training loop.
# Finally, we come to the training loop.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment