Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
torchani
Commits
7c253794
Unverified
Commit
7c253794
authored
Oct 26, 2018
by
Gao, Xiang
Committed by
GitHub
Oct 26, 2018
Browse files
Separate out neighborlist computer (#119)
parent
39137175
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
59 additions
and
44 deletions
+59
-44
torchani/aev.py
torchani/aev.py
+59
-44
No files found.
torchani/aev.py
View file @
7c253794
...
@@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff):
...
@@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff):
)
)
def
default_neighborlist
(
species
,
coordinates
,
cutoff
):
"""Default neighborlist computer"""
vec
=
coordinates
.
unsqueeze
(
2
)
-
coordinates
.
unsqueeze
(
1
)
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
distances
=
vec
.
norm
(
2
,
-
1
)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask
=
(
species
==
-
1
).
unsqueeze
(
1
)
distances
=
distances
.
masked_fill
(
padding_mask
,
math
.
inf
)
distances
,
indices
=
distances
.
sort
(
-
1
)
min_distances
,
_
=
distances
.
flatten
(
end_dim
=
1
).
min
(
0
)
in_cutoff
=
(
min_distances
<=
cutoff
).
nonzero
().
flatten
()[
1
:]
indices
=
indices
.
index_select
(
-
1
,
in_cutoff
)
# TODO: remove this workaround after gather support broadcasting
atoms
=
coordinates
.
shape
[
1
]
species_
=
species
.
unsqueeze
(
1
).
expand
(
-
1
,
atoms
,
-
1
)
neighbor_species
=
species_
.
gather
(
-
1
,
indices
)
neighbor_distances
=
distances
.
index_select
(
-
1
,
in_cutoff
)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
indices_
=
indices
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
-
1
,
3
)
neighbor_coordinates
=
vec
.
gather
(
-
2
,
indices_
)
return
neighbor_species
,
neighbor_distances
,
neighbor_coordinates
class
AEVComputer
(
torch
.
nn
.
Module
):
class
AEVComputer
(
torch
.
nn
.
Module
):
r
"""The AEV computer that takes coordinates as input and outputs aevs.
r
"""The AEV computer that takes coordinates as input and outputs aevs.
Arguments:
Arguments:
Rcr (
:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
Rcr (
float): :math:`R_C` in equation (2) when used at equation (3)
equation (2) when used at equation (3)
in the `ANI paper`_.
in the `ANI paper`_.
Rca (
:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
Rca (
float): :math:`R_C` in equation (2) when used at equation (4)
equation (2) when used at equation (4)
in the `ANI paper`_.
in the `ANI paper`_.
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (3) in the `ANI paper`_.
equation (3) in the `ANI paper`_.
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
...
@@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module):
...
@@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspoindingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper:
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
"""
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
def
__init__
(
self
,
Rcr
,
Rca
,
EtaR
,
ShfR
,
EtaA
,
Zeta
,
ShfA
,
ShfZ
,
num_species
):
num_species
,
neighborlist_computer
=
default_neighborlist
):
super
(
AEVComputer
,
self
).
__init__
()
super
(
AEVComputer
,
self
).
__init__
()
self
.
register_buffer
(
'
Rcr
'
,
Rcr
)
self
.
Rcr
=
Rcr
self
.
register_buffer
(
'
Rca
'
,
Rca
)
self
.
Rca
=
Rca
# convert constant tensors to a ready-to-broadcast shape
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
# shape convension (..., EtaR, ShfR)
self
.
register_buffer
(
'EtaR'
,
EtaR
.
view
(
-
1
,
1
))
self
.
register_buffer
(
'EtaR'
,
EtaR
.
view
(
-
1
,
1
))
...
@@ -54,6 +96,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -54,6 +96,7 @@ class AEVComputer(torch.nn.Module):
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
register_buffer
(
'ShfZ'
,
ShfZ
.
view
(
1
,
1
,
1
,
-
1
))
self
.
num_species
=
num_species
self
.
num_species
=
num_species
self
.
neighborlist
=
neighborlist_computer
def
radial_sublength
(
self
):
def
radial_sublength
(
self
):
"""Returns the length of radial subaev of a single species"""
"""Returns the length of radial subaev of a single species"""
...
@@ -147,33 +190,11 @@ class AEVComputer(torch.nn.Module):
...
@@ -147,33 +190,11 @@ class AEVComputer(torch.nn.Module):
cutoff radius are valid. The returned indices stores the source of data
cutoff radius are valid. The returned indices stores the source of data
before sorting.
before sorting.
"""
"""
max_cutoff
=
max
([
self
.
Rcr
,
self
.
Rca
])
vec
=
coordinates
.
unsqueeze
(
2
)
-
coordinates
.
unsqueeze
(
1
)
species_
,
distances
,
vec
=
self
.
neighborlist
(
species
,
coordinates
,
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
max_cutoff
)
distances
=
vec
.
norm
(
2
,
-
1
)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask
=
(
species
==
-
1
).
unsqueeze
(
1
)
distances
=
distances
.
masked_fill
(
padding_mask
,
math
.
inf
)
distances
,
indices
=
distances
.
sort
(
-
1
)
min_distances
,
_
=
distances
.
flatten
(
end_dim
=
1
).
min
(
0
)
inRcr
=
(
min_distances
<=
self
.
Rcr
).
nonzero
().
flatten
()[
1
:]
inRca
=
(
min_distances
<=
self
.
Rca
).
nonzero
().
flatten
()[
1
:]
distances
=
distances
.
index_select
(
-
1
,
inRcr
)
indices_r
=
indices
.
index_select
(
-
1
,
inRcr
)
radial_terms
=
self
.
_radial_subaev_terms
(
distances
)
radial_terms
=
self
.
_radial_subaev_terms
(
distances
)
indices_a
=
indices
.
index_select
(
-
1
,
inRca
)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a
=
indices_a
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
-
1
,
3
)
vec
=
vec
.
gather
(
-
2
,
_indices_a
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
vec
=
self
.
_combinations
(
vec
,
-
2
)
angular_terms
=
self
.
_angular_subaev_terms
(
*
vec
)
angular_terms
=
self
.
_angular_subaev_terms
(
*
vec
)
...
@@ -182,7 +203,7 @@ class AEVComputer(torch.nn.Module):
...
@@ -182,7 +203,7 @@ class AEVComputer(torch.nn.Module):
# (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, neighbors)
# (conformations, atoms, neighbors)
# (conformations, atoms, pairs)
# (conformations, atoms, pairs)
return
radial_terms
,
angular_terms
,
indices_r
,
indic
es_
a
return
radial_terms
,
angular_terms
,
speci
es_
def
_combinations
(
self
,
tensor
,
dim
=
0
):
def
_combinations
(
self
,
tensor
,
dim
=
0
):
# TODO: remove this when combinations is merged into PyTorch
# TODO: remove this when combinations is merged into PyTorch
...
@@ -199,16 +220,14 @@ class AEVComputer(torch.nn.Module):
...
@@ -199,16 +220,14 @@ class AEVComputer(torch.nn.Module):
return
tensor
.
index_select
(
dim
,
index1
),
\
return
tensor
.
index_select
(
dim
,
index1
),
\
tensor
.
index_select
(
dim
,
index2
)
tensor
.
index_select
(
dim
,
index2
)
def
_compute_mask_r
(
self
,
species
,
indices
_r
):
def
_compute_mask_r
(
self
,
species_r
):
"""Get mask of radial terms for each supported species from indices"""
"""Get mask of radial terms for each supported species from indices"""
species_r
=
species
.
gather
(
-
1
,
indices_r
)
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
mask_r
=
(
species_r
.
unsqueeze
(
-
1
)
==
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
))
torch
.
arange
(
self
.
num_species
,
device
=
self
.
EtaR
.
device
))
return
mask_r
return
mask_r
def
_compute_mask_a
(
self
,
species
,
indices
_a
,
present_species
):
def
_compute_mask_a
(
self
,
species_a
,
present_species
):
"""Get mask of angular terms for each supported species from indices"""
"""Get mask of angular terms for each supported species from indices"""
species_a
=
species
.
gather
(
-
1
,
indices_a
)
species_a1
,
species_a2
=
self
.
_combinations
(
species_a
,
-
1
)
species_a1
,
species_a2
=
self
.
_combinations
(
species_a
,
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a1
=
(
species_a1
.
unsqueeze
(
-
1
)
==
present_species
).
unsqueeze
(
-
1
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
mask_a2
=
(
species_a2
.
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
==
present_species
)
...
@@ -283,14 +302,10 @@ class AEVComputer(torch.nn.Module):
...
@@ -283,14 +302,10 @@ class AEVComputer(torch.nn.Module):
present_species
=
utils
.
present_species
(
species
)
present_species
=
utils
.
present_species
(
species
)
# TODO: remove this workaround after gather support broadcasting
radial_terms
,
angular_terms
,
species_
=
\
atoms
=
coordinates
.
shape
[
1
]
species_
=
species
.
unsqueeze
(
1
).
expand
(
-
1
,
atoms
,
-
1
)
radial_terms
,
angular_terms
,
indices_r
,
indices_a
=
\
self
.
_terms_and_indices
(
species
,
coordinates
)
self
.
_terms_and_indices
(
species
,
coordinates
)
mask_r
=
self
.
_compute_mask_r
(
species_
,
indices_r
)
mask_r
=
self
.
_compute_mask_r
(
species_
)
mask_a
=
self
.
_compute_mask_a
(
species_
,
indices_a
,
present_species
)
mask_a
=
self
.
_compute_mask_a
(
species_
,
present_species
)
radial
,
angular
=
self
.
_assemble
(
radial_terms
,
angular_terms
,
radial
,
angular
=
self
.
_assemble
(
radial_terms
,
angular_terms
,
present_species
,
mask_r
,
mask_a
)
present_species
,
mask_r
,
mask_a
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment