Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
73e447f0
Unverified
Commit
73e447f0
authored
Aug 03, 2018
by
Gao, Xiang
Committed by
GitHub
Aug 03, 2018
Browse files
improve energy shifter (#52)
parent
59b31d84
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
67 additions
and
76 deletions
+67
-76
examples/energy_force.py
examples/energy_force.py
+2
-3
examples/model.py
examples/model.py
+1
-1
examples/nnp_training.py
examples/nnp_training.py
+3
-3
examples/training-benchmark.py
examples/training-benchmark.py
+3
-3
tests/test_energies.py
tests/test_energies.py
+4
-4
tests/test_energyshifter.py
tests/test_energyshifter.py
+27
-0
tests/test_ignite.py
tests/test_ignite.py
+6
-5
torchani/energyshifter.py
torchani/energyshifter.py
+21
-57
No files found.
examples/energy_force.py
View file @
73e447f0
...
@@ -12,8 +12,8 @@ aev_computer = torchani.SortedAEV(const_file=const_file)
...
@@ -12,8 +12,8 @@ aev_computer = torchani.SortedAEV(const_file=const_file)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
nn
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
,
from_
=
network_dir
,
ensemble
=
8
)
ensemble
=
8
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
)
shift_energy
=
torchani
.
EnergyShifter
(
aev_computer
.
species
,
sae_file
)
shift_energy
=
torchani
.
EnergyShifter
(
sae_file
)
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nn
,
shift_energy
)
coordinates
=
torch
.
tensor
([[[
0.03192167
,
0.00638559
,
0.01301679
],
coordinates
=
torch
.
tensor
([[[
0.03192167
,
0.00638559
,
0.01301679
],
[
-
0.83140486
,
0.39370209
,
-
0.26395324
],
[
-
0.83140486
,
0.39370209
,
-
0.26395324
],
...
@@ -25,7 +25,6 @@ species = ['C', 'H', 'H', 'H', 'H']
...
@@ -25,7 +25,6 @@ species = ['C', 'H', 'H', 'H', 'H']
_
,
energy
=
model
((
species
,
coordinates
))
_
,
energy
=
model
((
species
,
coordinates
))
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
derivative
=
torch
.
autograd
.
grad
(
energy
.
sum
(),
coordinates
)[
0
]
energy
=
shift_energy
.
add_sae
(
energy
,
species
)
force
=
-
derivative
force
=
-
derivative
print
(
'Energy:'
,
energy
.
item
())
print
(
'Energy:'
,
energy
.
item
())
...
...
examples/model.py
View file @
73e447f0
...
@@ -41,4 +41,4 @@ def get_or_create_model(filename, benchmark=False,
...
@@ -41,4 +41,4 @@ def get_or_create_model(filename, benchmark=False,
model
.
load_state_dict
(
torch
.
load
(
filename
))
model
.
load_state_dict
(
torch
.
load
(
filename
))
else
:
else
:
torch
.
save
(
model
.
state_dict
(),
filename
)
torch
.
save
(
model
.
state_dict
(),
filename
)
return
model
.
to
(
device
)
return
model
.
to
(
device
)
,
torchani
.
EnergyShifter
(
aev_computer
.
species
)
examples/nnp_training.py
View file @
73e447f0
...
@@ -48,13 +48,13 @@ device = torch.device(parser.device)
...
@@ -48,13 +48,13 @@ device = torch.device(parser.device)
writer
=
tensorboardX
.
SummaryWriter
(
log_dir
=
parser
.
log
)
writer
=
tensorboardX
.
SummaryWriter
(
log_dir
=
parser
.
log
)
start
=
timeit
.
default_timer
()
start
=
timeit
.
default_timer
()
shift_energy
=
torchani
.
EnergyShifter
()
nnp
,
shift_energy
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
training
,
validation
,
testing
=
torchani
.
data
.
load_or_create
(
parser
.
dataset_checkpoint
,
parser
.
dataset_path
,
parser
.
chunk_size
,
parser
.
dataset_checkpoint
,
parser
.
dataset_path
,
parser
.
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_
subtract_
sae
])
device
=
device
,
transform
=
[
shift_energy
.
subtract_
from_dataset
])
training
=
torchani
.
data
.
dataloader
(
training
,
parser
.
batch_chunks
)
training
=
torchani
.
data
.
dataloader
(
training
,
parser
.
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
parser
.
batch_chunks
)
validation
=
torchani
.
data
.
dataloader
(
validation
,
parser
.
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
parser
.
model_checkpoint
,
device
=
device
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
parser
.
optim_args
=
json
.
loads
(
parser
.
optim_args
)
parser
.
optim_args
=
json
.
loads
(
parser
.
optim_args
)
...
...
examples/training-benchmark.py
View file @
73e447f0
...
@@ -24,12 +24,12 @@ parser = parser.parse_args()
...
@@ -24,12 +24,12 @@ parser = parser.parse_args()
# set up benchmark
# set up benchmark
device
=
torch
.
device
(
parser
.
device
)
device
=
torch
.
device
(
parser
.
device
)
shift_energy
=
torchani
.
EnergyShifter
()
nnp
,
shift_energy
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
dataset
=
torchani
.
data
.
ANIDataset
(
dataset
=
torchani
.
data
.
ANIDataset
(
parser
.
dataset_path
,
parser
.
chunk_size
,
device
=
device
,
parser
.
dataset_path
,
parser
.
chunk_size
,
device
=
device
,
transform
=
[
shift_energy
.
dataset_
subtract_
sae
])
transform
=
[
shift_energy
.
subtract_
from_dataset
])
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
parser
.
batch_chunks
)
dataloader
=
torchani
.
data
.
dataloader
(
dataset
,
parser
.
batch_chunks
)
nnp
=
model
.
get_or_create_model
(
'/tmp/model.pt'
,
True
,
device
=
device
)
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
container
=
torchani
.
ignite
.
Container
({
'energies'
:
nnp
})
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
optimizer
=
torch
.
optim
.
Adam
(
nnp
.
parameters
())
...
...
tests/test_energies.py
View file @
73e447f0
...
@@ -16,13 +16,13 @@ class TestEnergies(unittest.TestCase):
...
@@ -16,13 +16,13 @@ class TestEnergies(unittest.TestCase):
aev_computer
=
torchani
.
SortedAEV
()
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
)
shift_energy
=
torchani
.
EnergyShifter
(
aev_computer
.
species
)
self
.
model
=
torch
.
nn
.
Sequential
(
prepare
,
aev_computer
,
nnp
,
shift_energy
)
def
_test_molecule
(
self
,
coordinates
,
species
,
energies
):
def
_test_molecule
(
self
,
coordinates
,
species
,
energies
):
shift_energy
=
torchani
.
EnergyShifter
()
_
,
energies_
=
self
.
model
((
species
,
coordinates
))
_
,
energies_
=
self
.
model
((
species
,
coordinates
))
energies_
=
shift_energy
.
add_sae
(
energies_
.
squeeze
(),
species
)
max_diff
=
(
energies
-
energies_
.
squeeze
()).
abs
().
max
().
item
()
max_diff
=
(
energies
-
energies_
).
abs
().
max
().
item
()
self
.
assertLess
(
max_diff
,
self
.
tolerance
)
self
.
assertLess
(
max_diff
,
self
.
tolerance
)
def
testGDB
(
self
):
def
testGDB
(
self
):
...
...
tests/test_energyshifter.py
0 → 100644
View file @
73e447f0
import
torch
import
torchani
import
unittest
import
random
class
TestEnergyShifter
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
tol
=
1e-5
self
.
species
=
torchani
.
SortedAEV
().
species
self
.
prepare
=
torchani
.
PrepareInput
(
self
.
species
)
self
.
shift_energy
=
torchani
.
EnergyShifter
(
self
.
species
)
def
testSAEMatch
(
self
):
for
_
in
range
(
10
):
k
=
random
.
choice
(
range
(
5
,
30
))
species
=
random
.
choices
(
self
.
species
,
k
=
k
)
species_tensor
=
self
.
prepare
.
species_to_tensor
(
species
,
torch
.
device
(
'cpu'
))
e1
=
self
.
shift_energy
.
sae_from_list
(
species
)
e2
=
self
.
shift_energy
.
sae_from_tensor
(
species_tensor
)
self
.
assertLess
(
abs
(
e1
-
e2
),
self
.
tol
)
if
__name__
==
'__main__'
:
unittest
.
main
()
tests/test_ignite.py
View file @
73e447f0
...
@@ -17,14 +17,15 @@ if sys.version_info.major >= 3:
...
@@ -17,14 +17,15 @@ if sys.version_info.major >= 3:
class
TestIgnite
(
unittest
.
TestCase
):
class
TestIgnite
(
unittest
.
TestCase
):
def
testIgnite
(
self
):
def
testIgnite
(
self
):
shift_energy
=
torchani
.
EnergyShifter
()
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
transform
=
[
shift_energy
.
dataset_subtract_sae
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
aev_computer
=
torchani
.
SortedAEV
()
aev_computer
=
torchani
.
SortedAEV
()
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
prepare
=
torchani
.
PrepareInput
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
nnp
=
torchani
.
models
.
NeuroChemNNP
(
aev_computer
.
species
)
shift_energy
=
torchani
.
EnergyShifter
(
aev_computer
.
species
)
ds
=
torchani
.
data
.
ANIDataset
(
path
,
chunksize
,
transform
=
[
shift_energy
.
subtract_from_dataset
])
ds
=
torch
.
utils
.
data
.
Subset
(
ds
,
[
0
])
loader
=
torchani
.
data
.
dataloader
(
ds
,
1
)
class
Flatten
(
torch
.
nn
.
Module
):
class
Flatten
(
torch
.
nn
.
Module
):
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
...
torchani/energyshifter.py
View file @
73e447f0
import
torch
from
.env
import
buildin_sae_file
from
.env
import
buildin_sae_file
class
EnergyShifter
:
class
EnergyShifter
(
torch
.
nn
.
Module
):
"""Class that deal with self atomic energies.
Attributes
def
__init__
(
self
,
species
,
self_energy_file
=
buildin_sae_file
):
----------
super
(
EnergyShifter
,
self
).
__init__
()
self_energies : dict
The dictionary that stores self energies of species.
"""
def
__init__
(
self
,
self_energy_file
=
buildin_sae_file
):
# load self energies
# load self energies
self
.
self_energies
=
{}
self
.
self_energies
=
{}
with
open
(
self_energy_file
)
as
f
:
with
open
(
self_energy_file
)
as
f
:
...
@@ -22,55 +17,24 @@ class EnergyShifter:
...
@@ -22,55 +17,24 @@ class EnergyShifter:
self
.
self_energies
[
name
]
=
value
self
.
self_energies
[
name
]
=
value
except
Exception
:
except
Exception
:
pass
# ignore unrecognizable line
pass
# ignore unrecognizable line
self_energies_tensor
=
[
self
.
self_energies
[
s
]
for
s
in
species
]
self
.
register_buffer
(
'self_energies_tensor'
,
torch
.
tensor
(
self_energies_tensor
,
dtype
=
torch
.
double
))
def
subtract_sae
(
self
,
energies
,
species
):
def
sae_from_list
(
self
,
species
):
"""Subtract self atomic energies from `energies`.
energies
=
[
self
.
self_energies
[
i
]
for
i
in
species
]
return
sum
(
energies
)
Parameters
----------
energies : pytorch tensor of `dtype`
The tensor of any shape that stores the raw energies.
species : list of str
The list specifying the species of each atom. The length of the
list must be the same as the number of atoms.
Returns
-------
pytorch tensor of `dtype`
The tensor of the same shape as `energies` that stores the energies
with self atomic energies subtracted.
"""
s
=
0
for
i
in
species
:
s
+=
self
.
self_energies
[
i
]
return
energies
-
s
def
add_sae
(
self
,
energies
,
species
):
def
sae_from_tensor
(
self
,
species
):
"""Add self atomic energies to `energies`
return
self
.
self_energies_tensor
[
species
].
sum
().
item
()
Parameters
def
subtract_from_dataset
(
self
,
data
):
----------
sae
=
self
.
sae_from_list
(
data
[
'species'
])
energies : pytorch tensor of `dtype`
data
[
'energies'
]
-=
sae
The tensor of any shape that stores the energies excluding self
atomic energies.
species : list of str
The list specifying the species of each atom. The length of the
list must be the same as the number of atoms.
Returns
-------
pytorch tensor of `dtype`
The tensor of the same shape as `energies` that stores the raw
energies, i.e. the energy including self atomic energies.
"""
s
=
0
for
i
in
species
:
s
+=
self
.
self_energies
[
i
]
return
energies
+
s
def
dataset_subtract_sae
(
self
,
data
):
"""Allow object of this class to be used as transforms of pytorch's
dataset.
"""
data
[
'energies'
]
=
self
.
subtract_sae
(
data
[
'energies'
],
data
[
'species'
])
return
data
return
data
def
forward
(
self
,
species_energies
):
species
,
energies
=
species_energies
sae
=
self
.
sae_from_tensor
(
species
)
return
species
,
energies
+
sae
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment