Unverified Commit 7c253794 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Separate out neighborlist computer (#119)

parent 39137175
......@@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff):
)
def default_neighborlist(species, coordinates, cutoff):
"""Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
in_cutoff = (min_distances <= cutoff).nonzero().flatten()[1:]
indices = indices.index_select(-1, in_cutoff)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
neighbor_species = species_.gather(-1, indices)
neighbor_distances = distances.index_select(-1, in_cutoff)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
indices_ = indices.unsqueeze(-1).expand(-1, -1, -1, 3)
neighbor_coordinates = vec.gather(-2, indices_)
return neighbor_species, neighbor_distances, neighbor_coordinates
class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs.
Arguments:
Rcr (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (3) in the `ANI paper`_.
Rca (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (4) in the `ANI paper`_.
Rcr (float): :math:`R_C` in equation (2) when used at equation (3)
in the `ANI paper`_.
Rca (float): :math:`R_C` in equation (2) when used at equation (4)
in the `ANI paper`_.
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (3) in the `ANI paper`_.
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
......@@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspoindingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species):
num_species, neighborlist_computer=default_neighborlist):
super(AEVComputer, self).__init__()
self.register_buffer('Rcr', Rcr)
self.register_buffer('Rca', Rca)
self.Rcr = Rcr
self.Rca = Rca
# convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1))
......@@ -54,6 +96,7 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.num_species = num_species
self.neighborlist = neighborlist_computer
def radial_sublength(self):
"""Returns the length of radial subaev of a single species"""
......@@ -147,33 +190,11 @@ class AEVComputer(torch.nn.Module):
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors"""
distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
inRcr = (min_distances <= self.Rcr).nonzero().flatten()[1:]
inRca = (min_distances <= self.Rca).nonzero().flatten()[1:]
distances = distances.index_select(-1, inRcr)
indices_r = indices.index_select(-1, inRcr)
max_cutoff = max([self.Rcr, self.Rca])
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
radial_terms = self._radial_subaev_terms(distances)
indices_a = indices.index_select(-1, inRca)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(-1, -1, -1, 3)
vec = vec.gather(-2, _indices_a)
vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec)
......@@ -182,7 +203,7 @@ class AEVComputer(torch.nn.Module):
# (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, neighbors)
# (conformations, atoms, pairs)
return radial_terms, angular_terms, indices_r, indices_a
return radial_terms, angular_terms, species_
def _combinations(self, tensor, dim=0):
# TODO: remove this when combinations is merged into PyTorch
......@@ -199,16 +220,14 @@ class AEVComputer(torch.nn.Module):
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
def _compute_mask_r(self, species, indices_r):
def _compute_mask_r(self, species_r):
"""Get mask of radial terms for each supported species from indices"""
species_r = species.gather(-1, indices_r)
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(self.num_species, device=self.EtaR.device))
return mask_r
def _compute_mask_a(self, species, indices_a, present_species):
def _compute_mask_a(self, species_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self._combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
......@@ -283,14 +302,10 @@ class AEVComputer(torch.nn.Module):
present_species = utils.present_species(species)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
radial_terms, angular_terms, indices_r, indices_a = \
radial_terms, angular_terms, species_ = \
self._terms_and_indices(species, coordinates)
mask_r = self._compute_mask_r(species_, indices_r)
mask_a = self._compute_mask_a(species_, indices_a, present_species)
mask_r = self._compute_mask_r(species_)
mask_a = self._compute_mask_a(species_, present_species)
radial, angular = self._assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment