Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
1055f1f5
Commit
1055f1f5
authored
Nov 20, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 20, 2019
Browse files
Add element names to ANIModel (#398)
* Add element names to ANIModel * nc trainer
parent
66c3743c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
8 deletions
+25
-8
tests/test_neurochem.py
tests/test_neurochem.py
+4
-1
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+6
-5
torchani/nn.py
torchani/nn.py
+15
-2
No files found.
tests/test_neurochem.py
View file @
1055f1f5
...
...
@@ -18,7 +18,10 @@ class TestNeuroChem(unittest.TestCase):
# test if loader construct correct model
self
.
assertEqual
(
trainer
.
aev_computer
.
aev_length
,
384
)
m
=
trainer
.
nn
H
,
C
,
N
,
O
=
m
# noqa: E741
H
=
m
[
'H'
]
C
=
m
[
'C'
]
N
=
m
[
'N'
]
O
=
m
[
'O'
]
# noqa: E741
self
.
assertIsInstance
(
H
[
0
],
torch
.
nn
.
Linear
)
self
.
assertListEqual
(
list
(
H
[
0
].
weight
.
shape
),
[
160
,
384
])
self
.
assertIsInstance
(
H
[
1
],
torch
.
nn
.
CELU
)
...
...
torchani/neurochem/__init__.py
View file @
1055f1f5
...
...
@@ -15,6 +15,7 @@ from ..nn import ANIModel, Ensemble, Gaussian, Sequential
from
..utils
import
EnergyShifter
,
ChemicalSymbolsToInts
from
..aev
import
AEVComputer
from
..optim
import
AdamW
from
collections
import
OrderedDict
class
Constants
(
collections
.
abc
.
Mapping
):
...
...
@@ -240,10 +241,10 @@ def load_model(species, dir_):
chemical symbols of each supported atom type in correct order.
dir_ (str): String for directory storing network configurations.
"""
models
=
[]
models
=
OrderedDict
()
for
i
in
species
:
filename
=
os
.
path
.
join
(
dir_
,
'ANN-{}.nnf'
.
format
(
i
))
models
.
append
(
load_atomic_network
(
filename
)
)
models
[
i
]
=
load_atomic_network
(
filename
)
return
ANIModel
(
models
)
...
...
@@ -496,8 +497,8 @@ if sys.version_info[0] > 2:
input_size
,
network_setup
=
network_setup
if
input_size
!=
self
.
aev_computer
.
aev_length
:
raise
ValueError
(
'AEV size and input size does not match'
)
atomic_nets
=
{}
for
atom_type
in
network_setup
:
atomic_nets
=
OrderedDict
()
for
atom_type
in
self
.
consts
.
species
:
layers
=
network_setup
[
atom_type
]
modules
=
[]
i
=
input_size
...
...
@@ -537,7 +538,7 @@ if sys.version_info[0] > 2:
'unrecognized parameter in layer setup'
)
i
=
o
atomic_nets
[
atom_type
]
=
torch
.
nn
.
Sequential
(
*
modules
)
self
.
nn
=
ANIModel
(
[
atomic_nets
[
s
]
for
s
in
self
.
consts
.
species
]
)
self
.
nn
=
ANIModel
(
atomic_nets
)
# initialize weights and biases
self
.
nn
.
apply
(
init_params
)
...
...
torchani/nn.py
View file @
1055f1f5
import
torch
from
collections
import
OrderedDict
from
torch
import
Tensor
from
typing
import
Tuple
,
NamedTuple
,
Optional
...
...
@@ -13,7 +14,7 @@ class SpeciesCoordinates(NamedTuple):
coordinates
:
Tensor
class
ANIModel
(
torch
.
nn
.
Module
Lis
t
):
class
ANIModel
(
torch
.
nn
.
Module
Dic
t
):
"""ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing
...
...
@@ -31,6 +32,18 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`.
"""
@
staticmethod
def
ensureOrderedDict
(
modules
):
if
isinstance
(
modules
,
OrderedDict
):
return
modules
od
=
OrderedDict
()
for
i
,
m
in
enumerate
(
modules
):
od
[
str
(
i
)]
=
m
return
od
def
__init__
(
self
,
modules
):
super
(
ANIModel
,
self
).
__init__
(
self
.
ensureOrderedDict
(
modules
))
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
pbc
:
Optional
[
Tensor
]
=
None
)
->
SpeciesEnergies
:
...
...
@@ -42,7 +55,7 @@ class ANIModel(torch.nn.ModuleList):
output
=
aev
.
new_zeros
(
species_
.
shape
)
for
i
,
m
in
enumerate
(
self
):
for
i
,
(
_
,
m
)
in
enumerate
(
self
.
items
()
):
mask
=
(
species_
==
i
)
midx
=
mask
.
nonzero
().
flatten
()
if
midx
.
shape
[
0
]
>
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment