OpenDAS / torchani · commit 84439caf (unverified)
Authored Aug 03, 2018 by Gao, Xiang; committed by GitHub on Aug 03, 2018
parent 1d8bba37

merge Container and BatchModel to provide richer information in output (#50)

Showing 17 changed files with 64 additions and 102 deletions (+64 -102)
Changed files:
  examples/energy_force.py         +1  -1
  examples/model.py                +1  -1
  examples/nnp_training.py         +3  -4
  examples/training-benchmark.py   +2  -3
  tests/test_batch.py              +0  -32
  tests/test_data.py               +1  -1
  tests/test_energies.py           +3  -3
  tests/test_ensemble.py           +2  -2
  tests/test_forces.py             +1  -1
  tests/test_ignite.py             +4  -5
  torchani/data.py                 +18 -16
  torchani/ignite/__init__.py      +2  -4
  torchani/ignite/container.py     +10 -8
  torchani/ignite/loss_metrics.py  +14 -2
  torchani/models/__init__.py      +1  -2
  torchani/models/ani_model.py     +1  -1
  torchani/models/batch.py         +0  -16
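Taken together, the hunks below make one API change: ANIModel-based networks now return a (species, molecule_output) tuple instead of a bare energy tensor, torchani.models.BatchModel is deleted, and torchani.ignite.Container wraps the per-molecule model directly and handles batching itself. A condensed sketch of the new calling convention, pieced together from the snippets in this commit (the SortedAEV/PrepareInput/NeuroChemNNP construction mirrors the deleted tests/test_batch.py and the geometry is made up, so treat this as an illustration rather than a verified script):

import torch
import torchani
import torchani.ignite

# Single-molecule model, built the same way the deleted tests/test_batch.py did.
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
model = torch.nn.Sequential(prepare, aev_computer, nnp)

# New convention: the model returns (species, energies), not a bare tensor.
species = ['C', 'H', 'H', 'H', 'H']
coordinates = torch.randn(1, 5, 3, requires_grad=True)  # made-up geometry
_, energy = model((species, coordinates))
force = -torch.autograd.grad(energy.sum(), coordinates)[0]

# New convention: Container takes the model directly; BatchModel no longer exists.
container = torchani.ignite.Container({'energies': model})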
examples/energy_force.py
@@ -23,7 +23,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                            requires_grad=True)
 species = ['C', 'H', 'H', 'H', 'H']
-energy = model((species, coordinates))
+_, energy = model((species, coordinates))
 derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
 energy = shift_energy.add_sae(energy, species)
 force = -derivative
examples/model.py
@@ -34,7 +34,7 @@ def get_or_create_model(filename, benchmark=False,
     class Flatten(torch.nn.Module):
         def forward(self, x):
-            return x.flatten()
+            return x[0], x[1].flatten()
     model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
     if os.path.isfile(filename):
examples/nnp_training.py
@@ -55,17 +55,16 @@ training, validation, testing = torchani.data.load_or_create(
 training = torchani.data.dataloader(training, parser.batch_chunks)
 validation = torchani.data.dataloader(validation, parser.batch_chunks)
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
-batch_nnp = torchani.models.BatchModel(nnp)
-container = torchani.ignite.Container({'energies': batch_nnp})
+container = torchani.ignite.Container({'energies': nnp})
 parser.optim_args = json.loads(parser.optim_args)
 optimizer = getattr(torch.optim, parser.optimizer)
 optimizer = optimizer(nnp.parameters(), **parser.optim_args)
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.ignite.energy_mse_loss)
+    container, optimizer, torchani.ignite.MSELoss('energies'))
 evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-    'RMSE': torchani.ignite.energy_rmse_metric
+    'RMSE': torchani.ignite.RMSEMetric('energies')
 })
examples/training-benchmark.py
@@ -30,12 +30,11 @@ dataset = torchani.data.ANIDataset(
     transform=[shift_energy.dataset_subtract_sae])
 dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
 nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
-batch_nnp = torchani.models.BatchModel(nnp)
-container = torchani.ignite.Container({'energies': batch_nnp})
+container = torchani.ignite.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.ignite.energy_mse_loss)
+    container, optimizer, torchani.ignite.MSELoss('energies'))

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
tests/test_batch.py  (deleted, 100644 → 0)
-import sys

-if sys.version_info.major >= 3:
-    import os
-    import unittest
-    import torch
-    import torchani
-    import torchani.data
-    import itertools

-    path = os.path.dirname(os.path.realpath(__file__))
-    path = os.path.join(path, '../dataset')
-    chunksize = 32
-    batch_chunks = 32

-    class TestBatch(unittest.TestCase):

-        def testBatchLoadAndInference(self):
-            ds = torchani.data.ANIDataset(path, chunksize)
-            loader = torchani.data.dataloader(ds, batch_chunks)
-            aev_computer = torchani.SortedAEV()
-            prepare = torchani.PrepareInput(aev_computer.species)
-            nnp = torchani.models.NeuroChemNNP(aev_computer.species)
-            model = torch.nn.Sequential(prepare, aev_computer, nnp)
-            batch_nnp = torchani.models.BatchModel(model)
-            for batch_input, batch_output in itertools.islice(loader, 10):
-                batch_output_ = batch_nnp(batch_input).squeeze()
-                self.assertListEqual(list(batch_output_.shape),
-                                     list(batch_output['energies'].shape))

-if __name__ == '__main__':
-    unittest.main()
tests/test_data.py
@@ -12,7 +12,7 @@ if sys.version_info.major >= 3:
         def _test_chunksize(self, chunksize):
             ds = torchani.data.ANIDataset(path, chunksize)
-            for i in ds:
+            for i, _ in ds:
                 self.assertLessEqual(i['coordinates'].shape[0], chunksize)

         def testChunk64(self):
tests/test_energies.py
@@ -19,9 +19,9 @@ class TestEnergies(unittest.TestCase):
         self.model = torch.nn.Sequential(prepare, aev_computer, nnp)

     def _test_molecule(self, coordinates, species, energies):
-        shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file)
-        energies_ = self.model((species, coordinates)).squeeze()
-        energies_ = shift_energy.add_sae(energies_, species)
+        shift_energy = torchani.EnergyShifter()
+        _, energies_ = self.model((species, coordinates))
+        energies_ = shift_energy.add_sae(energies_.squeeze(), species)
         max_diff = (energies - energies_).abs().max().item()
         self.assertLess(max_diff, self.tolerance)
tests/test_ensemble.py
@@ -28,9 +28,9 @@ class TestEnsemble(unittest.TestCase):
                   for i in range(n)]
         models = [torch.nn.Sequential(prepare, aev, m) for m in models]
-        energy1 = ensemble((species, coordinates))
+        _, energy1 = ensemble((species, coordinates))
         force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
-        energy2 = [m((species, coordinates)) for m in models]
+        energy2 = [m((species, coordinates))[1] for m in models]
         energy2 = sum(energy2) / n
         force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
         energy_diff = (energy1 - energy2).abs().max().item()
tests/test_forces.py
@@ -19,7 +19,7 @@ class TestForce(unittest.TestCase):
     def _test_molecule(self, coordinates, species, forces):
         coordinates = torch.tensor(coordinates, requires_grad=True)
-        energies = self.model((species, coordinates))
+        _, energies = self.model((species, coordinates))
         derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
         max_diff = (forces + derivative).abs().max().item()
         self.assertLess(max_diff, self.tolerance)
tests/test_ignite.py
@@ -28,16 +28,15 @@ if sys.version_info.major >= 3:
     class Flatten(torch.nn.Module):
         def forward(self, x):
-            return x.flatten()
+            return x[0], x[1].flatten()

     model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
-    batch_nnp = torchani.models.BatchModel(model)
-    container = torchani.ignite.Container({'energies': batch_nnp})
+    container = torchani.ignite.Container({'energies': model})
     optimizer = torch.optim.Adam(container.parameters())
     trainer = create_supervised_trainer(
-        container, optimizer, torchani.ignite.energy_mse_loss)
+        container, optimizer, torchani.ignite.MSELoss('energies'))
     evaluator = create_supervised_evaluator(container, metrics={
-        'RMSE': torchani.ignite.energy_rmse_metric
+        'RMSE': torchani.ignite.RMSEMetric('energies')
     })

     @trainer.on(Events.COMPLETED)
torchani/data.py
@@ -5,6 +5,7 @@ from .pyanitools import anidataloader
 import torch
 import torch.utils.data as data
 import pickle
+import collections


 class ANIDataset(Dataset):
@@ -64,7 +65,9 @@ class ANIDataset(Dataset):
         self.chunks = chunks

     def __getitem__(self, idx):
-        return self.chunks[idx]
+        chunk = self.chunks[idx]
+        input_chunk = {k: chunk[k] for k in ('coordinates', 'species')}
+        return input_chunk, chunk

     def __len__(self):
         return len(self.chunks)
@@ -89,22 +92,21 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
     return training, validation, testing


-def _collate(batch):
-    input_keys = ['coordinates', 'species']
-    inputs = [{k: i[k] for k in input_keys} for i in batch]
-    outputs = {}
-    for i in batch:
-        for j in i:
-            if j in input_keys:
-                continue
-            if j not in outputs:
-                outputs[j] = []
-            outputs[j].append(i[j])
-    for i in outputs:
-        outputs[i] = torch.cat(outputs[i])
-    return inputs, outputs
+def collate(batch):
+    no_collate = ['coordinates', 'species']
+    if isinstance(batch[0], torch.Tensor):
+        return torch.cat(batch)
+    elif isinstance(batch[0], collections.Mapping):
+        return {key: ((lambda x: x) if key in no_collate else collate)(
+            [d[key] for d in batch]) for key in batch[0]}
+    elif isinstance(batch[0], collections.Sequence):
+        transposed = zip(*batch)
+        return [collate(samples) for samples in transposed]
+    else:
+        raise ValueError('Unexpected element type: {}'.format(type(batch[0])))


 def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
     return DataLoader(dataset, batch_chunks, shuffle,
-                      collate_fn=_collate, **kwargs)
+                      collate_fn=collate, **kwargs)
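The rewritten collate is generic rather than energy-specific: tensors are concatenated, mappings are collated key by key (with 'coordinates' and 'species' deliberately left as per-chunk lists), and sequences are transposed and collated element-wise, which lets the same function serve both the dataloader and the new Container.forward below. A toy illustration of that behaviour with made-up shapes:

import torch
from torchani.data import collate

# Two (inputs, outputs) pairs, in the shape ANIDataset.__getitem__ now yields;
# the sizes are invented for the example.
batch = [
    ({'coordinates': torch.randn(4, 5, 3), 'species': ['C', 'H', 'H', 'H', 'H']},
     {'energies': torch.randn(4)}),
    ({'coordinates': torch.randn(2, 3, 3), 'species': ['O', 'H', 'H']},
     {'energies': torch.randn(2)}),
]

inputs, outputs = collate(batch)
assert len(inputs['coordinates']) == 2     # kept as a per-chunk list
assert outputs['energies'].shape == (6,)   # concatenated along dim 0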
torchani/ignite/__init__.py
 from .container import Container
-from .loss_metrics import DictLoss, DictMetric, energy_mse_loss, \
-    energy_rmse_metric
+from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric

-__all__ = ['Container', 'DictLoss', 'DictMetric', 'energy_mse_loss',
-           'energy_rmse_metric']
+__all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric']
torchani/ignite/container.py
 import torch
-from ..models import BatchModel
+from ..data import collate


 class Container(torch.nn.Module):
@@ -8,13 +8,15 @@ class Container(torch.nn.Module):
         super(Container, self).__init__()
         self.keys = models.keys()
         for i in models:
-            if not isinstance(models[i], BatchModel):
-                raise ValueError('Container must contain batch models')
             setattr(self, 'model_' + i, models[i])

     def forward(self, batch):
-        output = {}
-        for i in self.keys:
-            model = getattr(self, 'model_' + i)
-            output[i] = model(batch)
-        return output
+        all_results = []
+        for i in zip(batch['species'], batch['coordinates']):
+            results = {}
+            for k in self.keys:
+                model = getattr(self, 'model_' + k)
+                _, results[k] = model(i)
+            all_results.append(results)
+        batch.update(collate(all_results))
+        return batch
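With BatchModel gone, Container.forward does the batching itself: it walks the per-chunk (species, coordinates) pairs, keeps only the second element of each wrapped model's (species, output) return, collates the per-chunk results, and merges them back into the batch dictionary, so the trainer now sees inputs and predictions together. A minimal sketch of that contract; ToyEnergies below is a hypothetical stand-in used only to show the expected call signature:

import torch
import torchani.ignite


class ToyEnergies(torch.nn.Module):
    """Hypothetical stand-in for an ANI model: returns (species, energies)."""

    def forward(self, species_coordinates):
        species, coordinates = species_coordinates
        # One fake energy per conformation in the chunk.
        return species, coordinates.view(coordinates.shape[0], -1).sum(dim=1)


container = torchani.ignite.Container({'energies': ToyEnergies()})
batch = {
    'species': [['C', 'H', 'H', 'H', 'H'], ['O', 'H', 'H']],
    'coordinates': [torch.randn(4, 5, 3), torch.randn(2, 3, 3)],
}
out = container(batch)
assert out['energies'].shape == (6,)   # collated predictions...
assert 'coordinates' in out            # ...alongside the original inputs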
torchani/ignite/loss_metrics.py
@@ -4,6 +4,14 @@ from ignite.metrics import RootMeanSquaredError
 import torch


+def num_atoms(input):
+    ret = []
+    for s, c in zip(input['species'], input['coordinates']):
+        ret.append(torch.full((c.shape[0],), len(s),
+                              dtype=c.dtype, device=c.device))
+    return torch.cat(ret)


 class DictLoss(_Loss):

     def __init__(self, key, loss):
@@ -33,5 +41,9 @@ class DictMetric(Metric):
         return self.metric.compute()


-energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
-energy_rmse_metric = DictMetric('energies', RootMeanSquaredError())
+def MSELoss(key):
+    return DictLoss(key, torch.nn.MSELoss())
+
+
+def RMSEMetric(key):
+    return DictMetric(key, RootMeanSquaredError())
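energy_mse_loss and energy_rmse_metric were hard-wired to the 'energies' key; the new MSELoss(key) and RMSEMetric(key) factories build the same DictLoss/DictMetric wrappers for any output key, which is what the updated examples and tests now pass to ignite. The new num_atoms helper expands each chunk's atom count to one value per conformation (presumably for per-atom normalization later; this commit does not show it being used). A toy call with made-up chunk shapes:

import torch
from torchani.ignite.loss_metrics import num_atoms

batch = {
    'species': [['C', 'H', 'H', 'H', 'H'], ['O', 'H', 'H']],
    'coordinates': [torch.randn(4, 5, 3), torch.randn(2, 3, 3)],
}
# One entry per conformation: a chunk with 4 conformations of a 5-atom molecule,
# then a chunk with 2 conformations of a 3-atom molecule.
assert num_atoms(batch).tolist() == [5.0, 5.0, 5.0, 5.0, 3.0, 3.0]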
torchani/models/__init__.py
 from .custom import CustomModel
 from .neurochem_nnp import NeuroChemNNP
-from .batch import BatchModel

-__all__ = ['CustomModel', 'NeuroChemNNP', 'BatchModel']
+__all__ = ['CustomModel', 'NeuroChemNNP']
torchani/models/ani_model.py
@@ -85,4 +85,4 @@ class ANIModel(BenchmarkedModule):
         per_species_outputs = torch.cat(per_species_outputs, dim=1)
         molecule_output = self.reducer(per_species_outputs, dim=1)
-        return molecule_output
+        return species, molecule_output
torchani/models/batch.py  (deleted, 100644 → 0)
-import torch


-class BatchModel(torch.nn.Module):

-    def __init__(self, model):
-        super(BatchModel, self).__init__()
-        self.model = model

-    def forward(self, batch):
-        results = []
-        for i in batch:
-            coordinates = i['coordinates']
-            species = i['species']
-            results.append(self.model((species, coordinates)))
-        return torch.cat(results)