Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
e0411f49
"vscode:/vscode.git/clone" did not exist on "4ad999d1440e896abec3f3c7029f292ce46cc820"
Unverified
Commit
e0411f49
authored
Mar 04, 2019
by
Gao, Xiang
Committed by
GitHub
Mar 04, 2019
Browse files
Move more computation ouside AEVComputer (#179)
parent
65bcbb45
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
111 additions
and
110 deletions
+111
-110
tests/test_aev.py
tests/test_aev.py
+1
-1
tests/test_neurochem.py
tests/test_neurochem.py
+1
-1
tools/training-benchmark.py
tools/training-benchmark.py
+5
-5
torchani/aev.py
torchani/aev.py
+103
-102
torchani/neurochem/__init__.py
torchani/neurochem/__init__.py
+1
-1
No files found.
tests/test_aev.py
View file @
e0411f49
...
@@ -14,7 +14,7 @@ class TestAEV(unittest.TestCase):
...
@@ -14,7 +14,7 @@ class TestAEV(unittest.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
builtins
=
torchani
.
neurochem
.
Builtins
()
builtins
=
torchani
.
neurochem
.
Builtins
()
self
.
aev_computer
=
builtins
.
aev_computer
self
.
aev_computer
=
builtins
.
aev_computer
self
.
radial_length
=
self
.
aev_computer
.
radial_length
()
self
.
radial_length
=
self
.
aev_computer
.
radial_length
self
.
tolerance
=
1e-5
self
.
tolerance
=
1e-5
def
random_skip
(
self
):
def
random_skip
(
self
):
...
...
tests/test_neurochem.py
View file @
e0411f49
...
@@ -16,7 +16,7 @@ class TestNeuroChem(unittest.TestCase):
...
@@ -16,7 +16,7 @@ class TestNeuroChem(unittest.TestCase):
trainer
=
torchani
.
neurochem
.
Trainer
(
iptpath
,
d
,
True
,
'runs'
)
trainer
=
torchani
.
neurochem
.
Trainer
(
iptpath
,
d
,
True
,
'runs'
)
# test if loader construct correct model
# test if loader construct correct model
self
.
assertEqual
(
trainer
.
aev_computer
.
aev_length
()
,
384
)
self
.
assertEqual
(
trainer
.
aev_computer
.
aev_length
,
384
)
m
=
trainer
.
model
m
=
trainer
.
model
H
,
C
,
N
,
O
=
m
# noqa: E741
H
,
C
,
N
,
O
=
m
# noqa: E741
self
.
assertIsInstance
(
H
[
0
],
torch
.
nn
.
Linear
)
self
.
assertIsInstance
(
H
[
0
],
torch
.
nn
.
Linear
)
...
...
tools/training-benchmark.py
View file @
e0411f49
...
@@ -98,10 +98,11 @@ torchani.aev._angular_subaev_terms = time_func(
...
@@ -98,10 +98,11 @@ torchani.aev._angular_subaev_terms = time_func(
'angular terms'
,
torchani
.
aev
.
_angular_subaev_terms
)
'angular terms'
,
torchani
.
aev
.
_angular_subaev_terms
)
nnp
[
0
].
_terms_and_indices
=
time_func
(
'terms and indices'
,
nnp
[
0
].
_terms_and_indices
=
time_func
(
'terms and indices'
,
nnp
[
0
].
_terms_and_indices
)
nnp
[
0
].
_terms_and_indices
)
nnp
[
0
].
_combinations
=
time_func
(
'combinations'
,
nnp
[
0
].
_combinations
)
torchani
.
aev
.
_compute_mask_r
=
time_func
(
'mask_r'
,
nnp
[
0
].
_compute_mask_r
=
time_func
(
'mask_r'
,
nnp
[
0
].
_compute_mask_r
)
torchani
.
aev
.
_compute_mask_r
)
nnp
[
0
].
_compute_mask_a
=
time_func
(
'mask_a'
,
nnp
[
0
].
_compute_mask_a
)
torchani
.
aev
.
_compute_mask_a
=
time_func
(
'mask_a'
,
nnp
[
0
].
_assemble
=
time_func
(
'assemble'
,
nnp
[
0
].
_assemble
)
torchani
.
aev
.
_compute_mask_a
)
torchani
.
aev
.
_assemble
=
time_func
(
'assemble'
,
torchani
.
aev
.
_assemble
)
nnp
[
0
].
forward
=
time_func
(
'total'
,
nnp
[
0
].
forward
)
nnp
[
0
].
forward
=
time_func
(
'total'
,
nnp
[
0
].
forward
)
nnp
[
1
].
forward
=
time_func
(
'forward'
,
nnp
[
1
].
forward
)
nnp
[
1
].
forward
=
time_func
(
'forward'
,
nnp
[
1
].
forward
)
...
@@ -112,7 +113,6 @@ elapsed = round(timeit.default_timer() - start, 2)
...
@@ -112,7 +113,6 @@ elapsed = round(timeit.default_timer() - start, 2)
print
(
'Radial terms:'
,
timers
[
'radial terms'
])
print
(
'Radial terms:'
,
timers
[
'radial terms'
])
print
(
'Angular terms:'
,
timers
[
'angular terms'
])
print
(
'Angular terms:'
,
timers
[
'angular terms'
])
print
(
'Terms and indices:'
,
timers
[
'terms and indices'
])
print
(
'Terms and indices:'
,
timers
[
'terms and indices'
])
print
(
'Combinations:'
,
timers
[
'combinations'
])
print
(
'Mask R:'
,
timers
[
'mask_r'
])
print
(
'Mask R:'
,
timers
[
'mask_r'
])
print
(
'Mask A:'
,
timers
[
'mask_a'
])
print
(
'Mask A:'
,
timers
[
'mask_a'
])
print
(
'Assemble:'
,
timers
[
'assemble'
])
print
(
'Assemble:'
,
timers
[
'assemble'
])
...
...
torchani/aev.py
View file @
e0411f49
...
@@ -121,6 +121,91 @@ def default_neighborlist(species, coordinates, cutoff):
...
@@ -121,6 +121,91 @@ def default_neighborlist(species, coordinates, cutoff):
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
# @torch.jit.script
def
_combinations
(
tensor
,
dim
=
0
):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
# @torch.jit.script
def
_compute_mask_r
(
species_r
,
num_species
):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
num_species
,
device
=
species_r
.
device
))
return
mask_r
@
torch
.
jit
.
script
def
_compute_mask_a
(
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
species_a1
,
species_a2
=
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask
=
mask_a1
&
mask_a2
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
mask
|
mask_rev
return
mask_a
# @torch.jit.script
def
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
rev_indices
=
present_species
.
new_full
((
num_species
,),
-
1
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
device
=
radial_terms
.
device
)
angular_aevs
=
[]
zero_angular_subaev
=
radial_terms
.
new_zeros
(
conformations
,
atoms
,
angular_sublength
)
for
s1
,
s2
in
torch
.
combinations
(
torch
.
arange
(
num_species
,
device
=
radial_terms
.
device
),
2
,
with_replacement
=
True
):
i1
=
rev_indices
[
s1
].
item
()
i2
=
rev_indices
[
s2
].
item
()
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
class
AEVComputer
(
torch
.
nn
.
Module
):
class
AEVComputer
(
torch
.
nn
.
Module
):
r
"""The AEV computer that takes coordinates as input and outputs aevs.
r
"""The AEV computer that takes coordinates as input and outputs aevs.
...
@@ -179,27 +264,18 @@ class AEVComputer(torch.nn.Module):
...
@@ -179,27 +264,18 @@ class AEVComputer(torch.nn.Module):
self
.
num_species
=
num_species
self
.
num_species
=
num_species
self
.
neighborlist
=
neighborlist_computer
self
.
neighborlist
=
neighborlist_computer
def
radial_sublength
(
self
):
# The length of radial subaev of a single species
"""Returns the length of radial subaev of a single species"""
self
.
radial_sublength
=
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
return
self
.
EtaR
.
numel
()
*
self
.
ShfR
.
numel
()
# The length of full radial aev
self
.
radial_length
=
self
.
num_species
*
self
.
radial_sublength
def
radial_length
(
self
):
# The length of angular subaev of a single species
"""Returns the length of full radial aev"""
self
.
angular_sublength
=
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
\
return
self
.
num_species
*
self
.
radial_sublength
()
self
.
ShfA
.
numel
()
*
self
.
ShfZ
.
numel
()
# The length of full angular aev
def
angular_sublength
(
self
):
self
.
angular_length
=
(
self
.
num_species
*
(
self
.
num_species
+
1
))
\
"""Returns the length of angular subaev of a single species"""
//
2
*
self
.
angular_sublength
return
self
.
EtaA
.
numel
()
*
self
.
Zeta
.
numel
()
*
self
.
ShfA
.
numel
()
*
\
# The length of full aev
self
.
ShfZ
.
numel
()
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
def
angular_length
(
self
):
"""Returns the length of full angular aev"""
s
=
self
.
num_species
return
(
s
*
(
s
+
1
))
//
2
*
self
.
angular_sublength
()
def
aev_length
(
self
):
"""Returns the length of full aev"""
return
self
.
radial_length
()
+
self
.
angular_length
()
def
_terms_and_indices
(
self
,
species
,
coordinates
):
def
_terms_and_indices
(
self
,
species
,
coordinates
):
"""Returns radial and angular subAEV terms, these terms will be sorted
"""Returns radial and angular subAEV terms, these terms will be sorted
...
@@ -213,88 +289,12 @@ class AEVComputer(torch.nn.Module):
...
@@ -213,88 +289,12 @@ class AEVComputer(torch.nn.Module):
radial_terms
=
_radial_subaev_terms
(
self
.
Rcr
,
self
.
EtaR
,
radial_terms
=
_radial_subaev_terms
(
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
distances
)
self
.
ShfR
,
distances
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
vec
=
_combinations
(
vec
,
-
2
)
angular_terms
=
_angular_subaev_terms
(
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
angular_terms
=
_angular_subaev_terms
(
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
,
*
vec
)
self
.
Zeta
,
self
.
ShfA
,
*
vec
)
return
radial_terms
,
angular_terms
,
species_
return
radial_terms
,
angular_terms
,
species_
def
_combinations
(
self
,
tensor
,
dim
=
0
):
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
def
_compute_mask_r
(
self
,
species_r
):
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
))
return
mask_r
def
_compute_mask_a
(
self
,
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
species_a1
,
species_a2
=
self
.
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask
=
mask_a1
&
mask_a2
mask_rev
=
mask
.
permute
(
0
,
1
,
2
,
4
,
3
)
mask_a
=
mask
|
mask_rev
return
mask_a
def
_assemble
(
self
,
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
rev_indices
=
self
.
EtaR
.
new_full
((
self
.
num_species
,),
-
1
,
dtype
=
torch
.
int64
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
device
=
self
.
EtaR
.
device
)
angular_aevs
=
[]
zero_angular_subaev
=
self
.
EtaR
.
new_zeros
(
conformations
,
atoms
,
self
.
angular_sublength
())
for
s1
,
s2
in
torch
.
combinations
(
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
),
2
,
with_replacement
=
True
):
i1
=
rev_indices
[
s1
].
item
()
i2
=
rev_indices
[
s2
].
item
()
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
self
.
EtaR
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
species_coordinates
):
"""Compute AEVs
"""Compute AEVs
...
@@ -315,10 +315,11 @@ class AEVComputer(torch.nn.Module):
...
@@ -315,10 +315,11 @@ class AEVComputer(torch.nn.Module):
radial_terms
,
angular_terms
,
species_
=
\
radial_terms
,
angular_terms
,
species_
=
\
self
.
_terms_and_indices
(
species
,
coordinates
)
self
.
_terms_and_indices
(
species
,
coordinates
)
mask_r
=
self
.
_compute_mask_r
(
species_
)
mask_r
=
_compute_mask_r
(
species_
,
self
.
num_species
)
mask_a
=
self
.
_compute_mask_a
(
species_
,
present_species
)
mask_a
=
_compute_mask_a
(
species_
,
present_species
)
radial
,
angular
=
self
.
_assemble
(
radial_terms
,
angular_terms
,
radial
,
angular
=
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
)
present_species
,
mask_r
,
mask_a
,
self
.
num_species
,
self
.
angular_sublength
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
return
species
,
fullaev
return
species
,
fullaev
torchani/neurochem/__init__.py
View file @
e0411f49
...
@@ -588,7 +588,7 @@ if sys.version_info[0] > 2:
...
@@ -588,7 +588,7 @@ if sys.version_info[0] > 2:
# construct networks
# construct networks
input_size
,
network_setup
=
network_setup
input_size
,
network_setup
=
network_setup
if
input_size
!=
self
.
aev_computer
.
aev_length
()
:
if
input_size
!=
self
.
aev_computer
.
aev_length
:
raise
ValueError
(
'AEV size and input size does not match'
)
raise
ValueError
(
'AEV size and input size does not match'
)
l2reg
=
[]
l2reg
=
[]
atomic_nets
=
{}
atomic_nets
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment