Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
55419dfa
Unverified
Commit
55419dfa
authored
Mar 06, 2019
by
Gao, Xiang
Committed by
GitHub
Mar 06, 2019
Browse files
JIT more (#180)
parent
e0411f49
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
97 additions
and
81 deletions
+97
-81
tools/training-benchmark.py
tools/training-benchmark.py
+0
-17
torchani/aev.py
torchani/aev.py
+93
-61
torchani/ignite.py
torchani/ignite.py
+1
-1
torchani/utils.py
torchani/utils.py
+3
-2
No files found.
tools/training-benchmark.py
View file @
55419dfa
...
@@ -92,17 +92,6 @@ def time_func(key, func):
...
@@ -92,17 +92,6 @@ def time_func(key, func):
# enable timers
# enable timers
torchani
.
aev
.
_radial_subaev_terms
=
time_func
(
'radial terms'
,
torchani
.
aev
.
_radial_subaev_terms
)
torchani
.
aev
.
_angular_subaev_terms
=
time_func
(
'angular terms'
,
torchani
.
aev
.
_angular_subaev_terms
)
nnp
[
0
].
_terms_and_indices
=
time_func
(
'terms and indices'
,
nnp
[
0
].
_terms_and_indices
)
torchani
.
aev
.
_compute_mask_r
=
time_func
(
'mask_r'
,
torchani
.
aev
.
_compute_mask_r
)
torchani
.
aev
.
_compute_mask_a
=
time_func
(
'mask_a'
,
torchani
.
aev
.
_compute_mask_a
)
torchani
.
aev
.
_assemble
=
time_func
(
'assemble'
,
torchani
.
aev
.
_assemble
)
nnp
[
0
].
forward
=
time_func
(
'total'
,
nnp
[
0
].
forward
)
nnp
[
0
].
forward
=
time_func
(
'total'
,
nnp
[
0
].
forward
)
nnp
[
1
].
forward
=
time_func
(
'forward'
,
nnp
[
1
].
forward
)
nnp
[
1
].
forward
=
time_func
(
'forward'
,
nnp
[
1
].
forward
)
...
@@ -110,12 +99,6 @@ nnp[1].forward = time_func('forward', nnp[1].forward)
...
@@ -110,12 +99,6 @@ nnp[1].forward = time_func('forward', nnp[1].forward)
start
=
timeit
.
default_timer
()
start
=
timeit
.
default_timer
()
trainer
.
run
(
dataset
,
max_epochs
=
1
)
trainer
.
run
(
dataset
,
max_epochs
=
1
)
elapsed
=
round
(
timeit
.
default_timer
()
-
start
,
2
)
elapsed
=
round
(
timeit
.
default_timer
()
-
start
,
2
)
print
(
'Radial terms:'
,
timers
[
'radial terms'
])
print
(
'Angular terms:'
,
timers
[
'angular terms'
])
print
(
'Terms and indices:'
,
timers
[
'terms and indices'
])
print
(
'Mask R:'
,
timers
[
'mask_r'
])
print
(
'Mask A:'
,
timers
[
'mask_a'
])
print
(
'Assemble:'
,
timers
[
'assemble'
])
print
(
'Total AEV:'
,
timers
[
'total'
])
print
(
'Total AEV:'
,
timers
[
'total'
])
print
(
'NN:'
,
timers
[
'forward'
])
print
(
'NN:'
,
timers
[
'forward'
])
print
(
'Epoch time:'
,
elapsed
)
print
(
'Epoch time:'
,
elapsed
)
torchani/aev.py
View file @
55419dfa
...
@@ -87,6 +87,38 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
...
@@ -87,6 +87,38 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return
ret
.
flatten
(
start_dim
=-
4
)
return
ret
.
flatten
(
start_dim
=-
4
)
@
torch
.
jit
.
script
def
_combinations
(
tensor
,
dim
=
0
):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
dtype
=
torch
.
long
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
@
torch
.
jit
.
script
def
_terms_and_indices
(
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
distances
,
vec
):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms
=
_radial_subaev_terms
(
Rcr
,
EtaR
,
ShfR
,
distances
)
vec
=
_combinations
(
vec
,
-
2
)
angular_terms
=
_angular_subaev_terms
(
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
*
vec
)
return
radial_terms
,
angular_terms
@
torch
.
jit
.
script
@
torch
.
jit
.
script
def
default_neighborlist
(
species
,
coordinates
,
cutoff
):
def
default_neighborlist
(
species
,
coordinates
,
cutoff
):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
...
@@ -121,24 +153,13 @@ def default_neighborlist(species, coordinates, cutoff):
...
@@ -121,24 +153,13 @@ def default_neighborlist(species, coordinates, cutoff):
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
# @torch.jit.script
@
torch
.
jit
.
script
def
_combinations
(
tensor
,
dim
=
0
):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n
=
tensor
.
shape
[
dim
]
if
n
==
0
:
return
tensor
,
tensor
r
=
torch
.
arange
(
n
,
device
=
tensor
.
device
)
index1
,
index2
=
torch
.
combinations
(
r
).
unbind
(
-
1
)
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
# @torch.jit.script
def
_compute_mask_r
(
species_r
,
num_species
):
def
_compute_mask_r
(
species_r
,
num_species
):
# type: (Tensor, int) -> Tensor
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
"""Get mask of radial terms for each supported species from indices"""
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
num_species
,
device
=
species_r
.
device
))
torch
.
arange
(
num_species
,
dtype
=
torch
.
long
,
device
=
species_r
.
device
))
return
mask_r
return
mask_r
...
@@ -154,7 +175,7 @@ def _compute_mask_a(species_a, present_species):
...
@@ -154,7 +175,7 @@ def _compute_mask_a(species_a, present_species):
return
mask_a
return
mask_a
#
@torch.jit.script
@
torch
.
jit
.
script
def
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
def
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
):
mask_r
,
mask_a
,
num_species
,
angular_sublength
):
"""Returns radial and angular AEV computed from terms according
"""Returns radial and angular AEV computed from terms according
...
@@ -172,40 +193,69 @@ def _assemble(radial_terms, angular_terms, present_species,
...
@@ -172,40 +193,69 @@ def _assemble(radial_terms, angular_terms, present_species,
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
pairs, present species, present species)
"""
"""
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501
conformations
=
radial_terms
.
shape
[
0
]
conformations
=
radial_terms
.
shape
[
0
]
atoms
=
radial_terms
.
shape
[
1
]
atoms
=
radial_terms
.
shape
[
1
]
# assemble radial subaev
# assemble radial subaev
present_radial_aevs
=
(
present_radial_aevs
=
(
radial_terms
.
unsqueeze
(
-
2
)
*
radial_terms
.
unsqueeze
(
-
2
)
*
mask_r
.
unsqueeze
(
-
1
).
t
ype
(
radial_terms
.
dtype
)
mask_r
.
unsqueeze
(
-
1
).
t
o
(
radial_terms
.
dtype
)
).
sum
(
-
3
)
).
sum
(
-
3
)
# present_radial_aevs has shape
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
# (conformations, atoms, present species, radial_length)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
radial_aevs
=
present_radial_aevs
.
flatten
(
start_dim
=
2
)
# assemble angular subaev
# assemble angular subaev
rev_indices
=
present_species
.
new_full
((
num_species
,),
-
1
)
rev_indices
=
torch
.
full
((
num_species
,),
-
1
,
dtype
=
present_species
.
dtype
,
device
=
present_species
.
device
)
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
rev_indices
[
present_species
]
=
torch
.
arange
(
present_species
.
numel
(),
dtype
=
torch
.
long
,
device
=
radial_terms
.
device
)
device
=
radial_terms
.
device
)
angular_aevs
=
[]
angular_aevs
=
[]
zero_angular_subaev
=
radial_terms
.
new_zeros
(
zero_angular_subaev
=
torch
.
zeros
(
conformations
,
atoms
,
angular_sublength
,
conformations
,
atoms
,
angular_sublength
)
dtype
=
radial_terms
.
dtype
,
for
s1
,
s2
in
torch
.
combinations
(
device
=
radial_terms
.
device
)
torch
.
arange
(
num_species
,
device
=
radial_terms
.
device
),
for
s1
in
range
(
num_species
):
2
,
with_replacement
=
True
):
# TODO: make PyTorch support range(start, end) and
i1
=
rev_indices
[
s1
].
item
()
# range(start, end, step) and remove the workaround
i2
=
rev_indices
[
s2
].
item
()
# below. The inner for loop should be:
if
i1
>=
0
and
i2
>=
0
:
# for s2 in range(s1, num_species):
mask
=
mask_a
[...,
i1
,
i2
].
unsqueeze
(
-
1
).
type
(
radial_terms
.
dtype
)
for
s2
in
range
(
num_species
-
s1
):
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
s2
+=
s1
else
:
i1
=
int
(
rev_indices
[
s1
])
subaev
=
zero_angular_subaev
i2
=
int
(
rev_indices
[
s2
])
angular_aevs
.
append
(
subaev
)
if
i1
>=
0
and
i2
>=
0
:
mask
=
mask_a
[:,
:,
:,
i1
,
i2
].
unsqueeze
(
-
1
)
\
.
to
(
radial_terms
.
dtype
)
subaev
=
(
angular_terms
*
mask
).
sum
(
-
2
)
else
:
subaev
=
zero_angular_subaev
angular_aevs
.
append
(
subaev
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
return
radial_aevs
,
torch
.
cat
(
angular_aevs
,
dim
=
2
)
@
torch
.
jit
.
script
def
_compute_aev
(
num_species
,
angular_sublength
,
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
species
,
species_
,
distances
,
vec
):
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
present_species
=
utils
.
present_species
(
species
)
radial_terms
,
angular_terms
=
_terms_and_indices
(
Rcr
,
EtaR
,
ShfR
,
Rca
,
ShfZ
,
EtaA
,
Zeta
,
ShfA
,
distances
,
vec
)
mask_r
=
_compute_mask_r
(
species_
,
num_species
)
mask_a
=
_compute_mask_a
(
species_
,
present_species
)
radial
,
angular
=
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
,
num_species
,
angular_sublength
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
return
species
,
fullaev
class
AEVComputer
(
torch
.
nn
.
Module
):
class
AEVComputer
(
torch
.
nn
.
Module
):
r
"""The AEV computer that takes coordinates as input and outputs aevs.
r
"""The AEV computer that takes coordinates as input and outputs aevs.
...
@@ -245,6 +295,9 @@ class AEVComputer(torch.nn.Module):
...
@@ -245,6 +295,9 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper:
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
"""
__constants__
=
[
'Rcr'
,
'Rca'
,
'num_species'
,
'radial_sublength'
,
'radial_length'
,
'angular_sublength'
,
'angular_length'
,
'aev_length'
]
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
,
neighborlist_computer
=
default_neighborlist
):
num_species
,
neighborlist_computer
=
default_neighborlist
):
...
@@ -277,24 +330,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -277,24 +330,7 @@ class AEVComputer(torch.nn.Module):
# The length of full aev
# The length of full aev
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
self
.
aev_length
=
self
.
radial_length
+
self
.
angular_length
def
_terms_and_indices
(
self
,
species
,
coordinates
):
# @torch.jit.script_method
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
max_cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
species_
,
distances
,
vec
=
self
.
neighborlist
(
species
,
coordinates
,
max_cutoff
)
radial_terms
=
_radial_subaev_terms
(
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
distances
)
vec
=
_combinations
(
vec
,
-
2
)
angular_terms
=
_angular_subaev_terms
(
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
,
*
vec
)
return
radial_terms
,
angular_terms
,
species_
def
forward
(
self
,
species_coordinates
):
def
forward
(
self
,
species_coordinates
):
"""Compute AEVs
"""Compute AEVs
...
@@ -309,17 +345,13 @@ class AEVComputer(torch.nn.Module):
...
@@ -309,17 +345,13 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
``(C, A, self.aev_length())``
"""
"""
species
,
coordinates
=
species_coordinates
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
present_species
=
utils
.
present_species
(
species
)
species
,
coordinates
=
species_coordinates
max_cutoff
=
max
(
self
.
Rcr
,
self
.
Rca
)
radial_terms
,
angular_terms
,
species_
=
\
species_
,
distances
,
vec
=
self
.
neighborlist
(
species
,
coordinates
,
self
.
_terms_and_indices
(
species
,
coordinates
)
max_cutoff
)
mask_r
=
_compute_mask_r
(
species_
,
self
.
num_species
)
return
_compute_aev
(
mask_a
=
_compute_mask_a
(
species_
,
present_species
)
self
.
num_species
,
self
.
angular_sublength
,
self
.
Rcr
,
self
.
EtaR
,
self
.
ShfR
,
self
.
Rca
,
self
.
ShfZ
,
self
.
EtaA
,
self
.
Zeta
,
self
.
ShfA
,
radial
,
angular
=
_assemble
(
radial_terms
,
angular_terms
,
species
,
species_
,
distances
,
vec
)
present_species
,
mask_r
,
mask_a
,
self
.
num_species
,
self
.
angular_sublength
)
fullaev
=
torch
.
cat
([
radial
,
angular
],
dim
=
2
)
return
species
,
fullaev
torchani/ignite.py
View file @
55419dfa
...
@@ -30,7 +30,7 @@ class Container(torch.nn.ModuleDict):
...
@@ -30,7 +30,7 @@ class Container(torch.nn.ModuleDict):
results
=
{
k
:
[]
for
k
in
self
}
results
=
{
k
:
[]
for
k
in
self
}
for
sx
in
species_x
:
for
sx
in
species_x
:
for
k
in
self
:
for
k
in
self
:
_
,
result
=
self
[
k
](
sx
)
_
,
result
=
self
[
k
](
tuple
(
sx
)
)
results
[
k
].
append
(
result
)
results
[
k
].
append
(
result
)
for
k
in
self
:
for
k
in
self
:
results
[
k
]
=
torch
.
cat
(
results
[
k
])
results
[
k
]
=
torch
.
cat
(
results
[
k
])
...
...
torchani/utils.py
View file @
55419dfa
...
@@ -65,6 +65,7 @@ def pad_coordinates(species_coordinates):
...
@@ -65,6 +65,7 @@ def pad_coordinates(species_coordinates):
return
torch
.
cat
(
species
),
torch
.
cat
(
coordinates
)
return
torch
.
cat
(
species
),
torch
.
cat
(
coordinates
)
@
torch
.
jit
.
script
def
present_species
(
species
):
def
present_species
(
species
):
"""Given a vector of species of atoms, compute the unique species present.
"""Given a vector of species of atoms, compute the unique species present.
...
@@ -74,8 +75,8 @@ def present_species(species):
...
@@ -74,8 +75,8 @@ def present_species(species):
Returns:
Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
"""
"""
present_species
=
species
.
flatten
().
unique
(
sorted
=
True
)
present_species
,
_
=
species
.
flatten
().
_
unique
(
sorted
=
True
)
if
present_species
[
0
]
.
item
(
)
==
-
1
:
if
int
(
present_species
[
0
])
==
-
1
:
present_species
=
present_species
[
1
:]
present_species
=
present_species
[
1
:]
return
present_species
return
present_species
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment