Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
e41c2d93
Commit
e41c2d93
authored
Nov 13, 2019
by
Gao, Xiang
Committed by
Farhad Ramezanghorbani
Nov 13, 2019
Browse files
Subclass ModuleList to simplify code (#385)
parent
93372134
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
19 deletions
+10
-19
torchani/nn.py
torchani/nn.py
+10
-19
No files found.
torchani/nn.py
View file @
e41c2d93
...
@@ -8,7 +8,7 @@ class SpeciesEnergies(NamedTuple):
...
@@ -8,7 +8,7 @@ class SpeciesEnergies(NamedTuple):
energies
:
Tensor
energies
:
Tensor
class
ANIModel
(
torch
.
nn
.
Module
):
class
ANIModel
(
torch
.
nn
.
Module
List
):
"""ANI model that compute energies from species and AEVs.
"""ANI model that compute energies from species and AEVs.
Different atom types might have different modules, when computing
Different atom types might have different modules, when computing
...
@@ -27,11 +27,7 @@ class ANIModel(torch.nn.Module):
...
@@ -27,11 +27,7 @@ class ANIModel(torch.nn.Module):
"""
"""
def
__init__
(
self
,
modules
):
def
__init__
(
self
,
modules
):
super
(
ANIModel
,
self
).
__init__
()
super
(
ANIModel
,
self
).
__init__
(
modules
)
self
.
module_list
=
torch
.
nn
.
ModuleList
(
modules
)
def
__getitem__
(
self
,
i
):
return
self
.
module_list
[
i
]
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
],
def
forward
(
self
,
species_aev
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
cell
:
Optional
[
Tensor
]
=
None
,
...
@@ -44,7 +40,7 @@ class ANIModel(torch.nn.Module):
...
@@ -44,7 +40,7 @@ class ANIModel(torch.nn.Module):
output
=
aev
.
new_zeros
(
species_
.
shape
)
output
=
aev
.
new_zeros
(
species_
.
shape
)
for
i
,
m
in
enumerate
(
self
.
module_list
):
for
i
,
m
in
enumerate
(
self
):
mask
=
(
species_
==
i
)
mask
=
(
species_
==
i
)
midx
=
mask
.
nonzero
().
flatten
()
midx
=
mask
.
nonzero
().
flatten
()
if
midx
.
shape
[
0
]
>
0
:
if
midx
.
shape
[
0
]
>
0
:
...
@@ -54,13 +50,12 @@ class ANIModel(torch.nn.Module):
...
@@ -54,13 +50,12 @@ class ANIModel(torch.nn.Module):
return
SpeciesEnergies
(
species
,
torch
.
sum
(
output
,
dim
=
1
))
return
SpeciesEnergies
(
species
,
torch
.
sum
(
output
,
dim
=
1
))
class
Ensemble
(
torch
.
nn
.
Module
):
class
Ensemble
(
torch
.
nn
.
Module
List
):
"""Compute the average output of an ensemble of modules."""
"""Compute the average output of an ensemble of modules."""
def
__init__
(
self
,
modules
):
def
__init__
(
self
,
modules
):
super
(
Ensemble
,
self
).
__init__
()
super
(
Ensemble
,
self
).
__init__
(
modules
)
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
self
.
size
=
len
(
modules
)
self
.
size
=
len
(
self
.
modules_list
)
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
],
def
forward
(
self
,
species_input
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
cell
:
Optional
[
Tensor
]
=
None
,
...
@@ -68,26 +63,22 @@ class Ensemble(torch.nn.Module):
...
@@ -68,26 +63,22 @@ class Ensemble(torch.nn.Module):
assert
cell
is
None
assert
cell
is
None
assert
pbc
is
None
assert
pbc
is
None
sum_
=
0
sum_
=
0
for
x
in
self
.
modules_list
:
for
x
in
self
:
sum_
+=
x
(
species_input
)[
1
]
sum_
+=
x
(
species_input
)[
1
]
species
,
_
=
species_input
species
,
_
=
species_input
return
SpeciesEnergies
(
species
,
sum_
/
self
.
size
)
return
SpeciesEnergies
(
species
,
sum_
/
self
.
size
)
def
__getitem__
(
self
,
i
):
return
self
.
modules_list
[
i
]
class
Sequential
(
torch
.
nn
.
Module
):
class
Sequential
(
torch
.
nn
.
Module
List
):
"""Modified Sequential module that accept Tuple type as input"""
"""Modified Sequential module that accept Tuple type as input"""
def
__init__
(
self
,
*
modules
):
def
__init__
(
self
,
*
modules
):
super
(
Sequential
,
self
).
__init__
()
super
(
Sequential
,
self
).
__init__
(
modules
)
self
.
modules_list
=
torch
.
nn
.
ModuleList
(
modules
)
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
def
forward
(
self
,
input_
:
Tuple
[
Tensor
,
Tensor
],
cell
:
Optional
[
Tensor
]
=
None
,
cell
:
Optional
[
Tensor
]
=
None
,
pbc
:
Optional
[
Tensor
]
=
None
):
pbc
:
Optional
[
Tensor
]
=
None
):
for
module
in
self
.
modules_list
:
for
module
in
self
:
input_
=
module
(
input_
,
cell
=
cell
,
pbc
=
pbc
)
input_
=
module
(
input_
,
cell
=
cell
,
pbc
=
pbc
)
cell
=
None
cell
=
None
pbc
=
None
pbc
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment