Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
e0411f49
Unverified
Commit
e0411f49
authored
Mar 04, 2019
by
Gao, Xiang
Committed by
GitHub
Mar 04, 2019
Browse files
Move more computation ouside AEVComputer (#179)
parent
65bcbb45
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
110 deletions
+111
-110
tests/test_aev.py
tests/test_aev.py
+1
-1
tests/test_neurochem.py
tests/test_neurochem.py
+1
-1
tools/training-benchmark.py
tools/training-benchmark.py
+5
-5
torchani/aev.py
torchani/aev.py
+103
-102
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+1
-1
No files found.
tests/test_aev.py
View file @
e0411f49
...
...
@@ -14,7 +14,7 @@ class TestAEV(unittest.TestCase):
def
setUp
(
self
):
builtins
=
torchani
.
neurochem
.
Builtins
()
self
.
aev_computer
=
builtins
.
aev_computer
self
.
radial_length
=
self
.
aev_computer
.
radial_length
()
self
.
radial_length
=
self
.
aev_computer
.
radial_length
self
.
tolerance
=
1e-5
def
random_skip
(
self
):
...
...
tests/test_neurochem.py
View file @
e0411f49
...
...
@@ -16,7 +16,7 @@ class TestNeuroChem(unittest.TestCase):
trainer
=
torchani
.
neurochem
.
Trainer
(
iptpath
,
d
,
True
,
'runs'
)
# test if loader construct correct model
self
.
assertEqual
(
trainer
.
aev_computer
.
aev_length
()
,
384
)
self
.
assertEqual
(
trainer
.
aev_computer
.
aev_length
,
384
)
m
=
trainer
.
model
H
,
C
,
N
,
O
=
m
# noqa: E741
self
.
assertIsInstance
(
H
[
0
],
torch
.
nn
.
Linear
)
...
...
tools/training-benchmark.py
View file @
e0411f49
...
...
@@ -98,10 +98,11 @@ torchani.aev._angular_subaev_terms = time_func(
'angular terms'
,
torchani
.
aev
.
_angular_subaev_terms
)
nnp
[
0
].
_terms_and_indices
=
time_func
(
'terms and indices'
,
nnp
[
0
].
_terms_and_indices
)
nnp
[
0
].
_combinations
=
time_func
(
'combinations'
,
nnp
[
0
].
_combinations
)
nnp
[
0
].
_compute_mask_r
=
time_func
(
'mask_r'
,
nnp
[
0
].
_compute_mask_r
)
nnp
[
0
].
_compute_mask_a
=
time_func
(
'mask_a'
,
nnp
[
0
].
_compute_mask_a
)
nnp
[
0
].
_assemble
=
time_func
(
'assemble'
,
nnp
[
0
].
_assemble
)
torchani
.
aev
.
_compute_mask_r
=
time_func
(
'mask_r'
,
torchani
.
aev
.
_compute_mask_r
)
torchani
.
aev
.
_compute_mask_a
=
time_func
(
'mask_a'
,
torchani
.
aev
.
_compute_mask_a
)
torchani
.
aev
.
_assemble
=
time_func
(
'assemble'
,
torchani
.
aev
.
_assemble
)
nnp
[
0
].
forward
=
time_func
(
'total'
,
nnp
[
0
].
forward
)
nnp
[
1
].
forward
=
time_func
(
'forward'
,
nnp
[
1
].
forward
)
...
...
@@ -112,7 +113,6 @@ elapsed = round(timeit.default_timer() - start, 2)
print
(
'Radial terms:'
,
timers
[
'radial terms'
])
print
(
'Angular terms:'
,
timers
[
'angular terms'
])
print
(
'Terms and indices:'
,
timers
[
'terms and indices'
])
print
(
'Combinations:'
,
timers
[
'combinations'
])
print
(
'Mask R:'
,
timers
[
'mask_r'
])
print
(
'Mask A:'
,
timers
[
'mask_a'
])
print
(
'Assemble:'
,
timers
[
'assemble'
])
...
...
torchani/aev.py
View file @
e0411f49
...
...
@@ -121,6 +121,91 @@ def default_neighborlist(species, coordinates, cutoff):
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
# @torch.jit.script
def
_combinations
(
tensor
,
dim
=
0
):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
# @torch.jit.script
def
_compute_mask_r
(
species_r
,
num_species
):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
num_species
,
device
=
species_r
.
device
))
return
mask_r
@
torch
.
jit
.
script
def
_compute_mask_a
(
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
species_a1
,
species_a2
=
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask
=
mask_a1
&
mask_a2
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
mask
|
mask_rev
return
mask_a
# @torch.jit.script
def
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
rev_indices
=
present_species
.
new_full
((
num_species
,),
-
1
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
device
=
radial_terms
.
device
)
angular_aevs
=
[]
zero_angular_subaev
=
radial_terms
.
new_zeros
(
conformations
,
atoms
,
angular_sublength
)
for
s1
,
s2
in
torch
.
combinations
(
torch
.
arange
(
num_species
,
device
=
radial_terms
.
device
),
2
,
with_replacement
=
True
):
i1
=
rev_indices
[
s1
].
item
()
i2
=
rev_indices
[
s2
].
item
()
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
class
AEVComputer
(
torch
.
nn
.
Module
):
r
"""The AEV computer that takes coordinates as input and outputs aevs.
...
...
@@ -179,27 +264,18 @@ class AEVComputer(torch.nn.Module):
self
.
num_species
=
num_species
self
.
neighborlist
=
neighborlist_computer
def
radial_sublength
(
self
):
"""Returns the length of radial subaev of a single species"""
return
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
def
radial_length
(
self
):
"""Returns the length of full radial aev"""
return
self
.
num_species
*
self
.
radial_sublength
()
def
angular_sublength
(
self
):
"""Returns the length of angular subaev of a single species"""
return
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
self
.
ShfA
.
numel
()
*
\
self
.
ShfZ
.
numel
()
def
angular_length
(
self
):
"""Returns the length of full angular aev"""
s
=
self
.
num_species
return
(
s
*
(
s
+
1
))
//
2
*
self
.
angular_sublength
()
def
aev_length
(
self
):
"""Returns the length of full aev"""
return
self
.
radial_length
()
+
self
.
angular_length
()
# The length of radial subaev of a single species
self
.
radial_sublength
=
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
# The length of full radial aev
self
.
radial_length
=
self
.
num_species
*
self
.
radial_sublength
# The length of angular subaev of a single species
self
.
angular_sublength
=
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
\
self
.
ShfA
.
numel
()
*
self
.
ShfZ
.
numel
()
# The length of full angular aev
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
\
//
2
*
self
.
angular_sublength
# The length of full aev
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
def
_terms_and_indices
(
self
,
species
,
coordinates
):
"""Returns radial and angular subAEV terms, these terms will be sorted
...
...
@@ -213,88 +289,12 @@ class AEVComputer(torch.nn.Module):
radial_terms
=
_radial_subaev_terms
(
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
distances
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
vec
=
_combinations
(
vec
,
-
2
)
angular_terms
=
_angular_subaev_terms
(
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
,
*
vec
)
return
radial_terms
,
angular_terms
,
species_
def
_combinations
(
self
,
tensor
,
dim
=
0
):
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
def
_compute_mask_r
(
self
,
species_r
):
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
))
return
mask_r
def
_compute_mask_a
(
self
,
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
species_a1
,
species_a2
=
self
.
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask
=
mask_a1
&
mask_a2
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
mask
|
mask_rev
return
mask_a
def
_assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
rev_indices
=
self
.
EtaR
.
new_full
((
self
.
num_species
,),
-
1
,
dtype
=
torch
.
int64
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
device
=
self
.
EtaR
.
device
)
angular_aevs
=
[]
zero_angular_subaev
=
self
.
EtaR
.
new_zeros
(
conformations
,
atoms
,
self
.
angular_sublength
())
for
s1
,
s2
in
torch
.
combinations
(
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
),
2
,
with_replacement
=
True
):
i1
=
rev_indices
[
s1
].
item
()
i2
=
rev_indices
[
s2
].
item
()
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
EtaR
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
def
forward
(
self
,
species_coordinates
):
"""Compute AEVs
...
...
@@ -315,10 +315,11 @@ class AEVComputer(torch.nn.Module):
radial_terms
,
angular_terms
,
species_
=
\
self
.
_terms_and_indices
(
species
,
coordinates
)
mask_r
=
self
.
_compute_mask_r
(
species_
)
mask_a
=
self
.
_compute_mask_a
(
species_
,
present_species
)
mask_r
=
_compute_mask_r
(
species_
,
self
.
num_species
)
mask_a
=
_compute_mask_a
(
species_
,
present_species
)
radial
,
angular
=
self
.
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
)
radial
,
angular
=
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
self
.
num_species
,
self
.
angular_sublength
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
return
species
,
fullaev
torchani/neurochem/__init__.py
View file @
e0411f49
...
...
@@ -588,7 +588,7 @@ if sys.version_info[0] > 2:
# construct networks
input_size
,
network_setup
=
network_setup
if
input_size
!=
self
.
aev_computer
.
aev_length
()
:
if
input_size
!=
self
.
aev_computer
.
aev_length
:
raise
ValueError
(
'AEV size and input size does not match'
)
l2reg
=
[]
atomic_nets
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment