Unverified Commit 7c253794 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Separate out neighborlist computer (#119)

parent 39137175
...@@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff): ...@@ -12,14 +12,46 @@ def _cutoff_cosine(distances, cutoff):
) )
def default_neighborlist(species, coordinates, cutoff):
    """Default neighborlist computer.

    Computes all pairwise displacement vectors, sorts every atom's
    neighbors by distance and keeps the neighbor ranks that fall within
    ``cutoff`` for at least one atom of at least one conformation.

    Arguments:
        species (:class:`torch.Tensor`): Integer tensor of shape
            ``(conformations, atoms)``; entries equal to ``-1`` mark
            padding atoms.
        coordinates (:class:`torch.Tensor`): Tensor of shape
            ``(conformations, atoms, D)`` holding atomic positions. ``D``
            is normally 3, but is read from the input rather than assumed.
        cutoff (float): Cutoff radius used to decide which neighbor ranks
            are kept.

    Returns:
        tuple: ``(neighbor_species, neighbor_distances,
        neighbor_coordinates)`` of shapes ``(conformations, atoms, N)``,
        ``(conformations, atoms, N)`` and ``(conformations, atoms, N, D)``.
        Because a rank is kept as soon as any atom has a neighbor of that
        rank inside the cutoff, individual entries may still be farther
        than ``cutoff`` (or ``inf`` for padding neighbors).
    """
    # Shape (conformations, atoms, atoms, D) storing Rij vectors
    vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
    # Shape (conformations, atoms, atoms) storing Rij distances
    distances = vec.norm(2, -1)

    # Push padding atoms to infinite distance so they sort last and can
    # never be the reason a neighbor rank is kept.
    padding_mask = (species == -1).unsqueeze(1)
    distances = distances.masked_fill(padding_mask, math.inf)

    # Sort neighbors of every atom by distance; ``indices`` remembers which
    # atom each sorted entry came from.
    distances, indices = distances.sort(-1)

    # Smallest distance found at each neighbor rank over all conformations
    # and atoms. The leading kept rank is dropped by ``[1:]``: for a real
    # atom it is the atom itself at distance zero.
    min_distances, _ = distances.flatten(end_dim=1).min(0)
    in_cutoff = (min_distances <= cutoff).nonzero().flatten()[1:]

    indices = indices.index_select(-1, in_cutoff)
    neighbor_distances = distances.index_select(-1, in_cutoff)

    # TODO: remove these explicit expands once gather supports broadcasting
    # https://github.com/pytorch/pytorch/pull/9532
    atoms = coordinates.shape[1]
    species_ = species.unsqueeze(1).expand(-1, atoms, -1)
    neighbor_species = species_.gather(-1, indices)

    # Expand the index over the coordinate axis, using the actual size of
    # that axis instead of a hard-coded 3.
    indices_ = indices.unsqueeze(-1).expand(-1, -1, -1, vec.shape[-1])
    neighbor_coordinates = vec.gather(-2, indices_)

    return neighbor_species, neighbor_distances, neighbor_coordinates
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs. r"""The AEV computer that takes coordinates as input and outputs aevs.
Arguments: Arguments:
Rcr (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in Rcr (float): :math:`R_C` in equation (2) when used at equation (3)
equation (2) when used at equation (3) in the `ANI paper`_. in the `ANI paper`_.
Rca (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in Rca (float): :math:`R_C` in equation (2) when used at equation (4)
equation (2) when used at equation (4) in the `ANI paper`_. in the `ANI paper`_.
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (3) in the `ANI paper`_. equation (3) in the `ANI paper`_.
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
...@@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module): ...@@ -33,16 +65,26 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_. equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types. num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` respectively, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species): num_species, neighborlist_computer=default_neighborlist):
super(AEVComputer, self).__init__() super(AEVComputer, self).__init__()
self.register_buffer('Rcr', Rcr) self.Rcr = Rcr
self.register_buffer('Rca', Rca) self.Rca = Rca
# convert constant tensors to a ready-to-broadcast shape # convert constant tensors to a ready-to-broadcast shape
# shape convention (..., EtaR, ShfR) # shape convention (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1)) self.register_buffer('EtaR', EtaR.view(-1, 1))
...@@ -54,6 +96,7 @@ class AEVComputer(torch.nn.Module): ...@@ -54,6 +96,7 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1)) self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.num_species = num_species self.num_species = num_species
self.neighborlist = neighborlist_computer
def radial_sublength(self): def radial_sublength(self):
"""Returns the length of radial subaev of a single species""" """Returns the length of radial subaev of a single species"""
...@@ -147,33 +190,11 @@ class AEVComputer(torch.nn.Module): ...@@ -147,33 +190,11 @@ class AEVComputer(torch.nn.Module):
cutoff radius are valid. The returned indices stores the source of data cutoff radius are valid. The returned indices stores the source of data
before sorting. before sorting.
""" """
max_cutoff = max([self.Rcr, self.Rca])
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1) species_, distances, vec = self.neighborlist(species, coordinates,
"""Shape (conformations, atoms, atoms, 3) storing Rij vectors""" max_cutoff)
distances = vec.norm(2, -1)
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
inRcr = (min_distances <= self.Rcr).nonzero().flatten()[1:]
inRca = (min_distances <= self.Rca).nonzero().flatten()[1:]
distances = distances.index_select(-1, inRcr)
indices_r = indices.index_select(-1, inRcr)
radial_terms = self._radial_subaev_terms(distances) radial_terms = self._radial_subaev_terms(distances)
indices_a = indices.index_select(-1, inRca)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(-1, -1, -1, 3)
vec = vec.gather(-2, _indices_a)
vec = self._combinations(vec, -2) vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec) angular_terms = self._angular_subaev_terms(*vec)
...@@ -182,7 +203,7 @@ class AEVComputer(torch.nn.Module): ...@@ -182,7 +203,7 @@ class AEVComputer(torch.nn.Module):
# (conformations, atoms, pairs, ``self.angular_sublength()``) # (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, neighbors) # (conformations, atoms, neighbors)
# (conformations, atoms, pairs) # (conformations, atoms, pairs)
return radial_terms, angular_terms, indices_r, indices_a return radial_terms, angular_terms, species_
def _combinations(self, tensor, dim=0): def _combinations(self, tensor, dim=0):
# TODO: remove this when combinations is merged into PyTorch # TODO: remove this when combinations is merged into PyTorch
...@@ -199,16 +220,14 @@ class AEVComputer(torch.nn.Module): ...@@ -199,16 +220,14 @@ class AEVComputer(torch.nn.Module):
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
def _compute_mask_r(self, species, indices_r): def _compute_mask_r(self, species_r):
"""Get mask of radial terms for each supported species from indices""" """Get mask of radial terms for each supported species from indices"""
species_r = species.gather(-1, indices_r)
mask_r = (species_r.unsqueeze(-1) == mask_r = (species_r.unsqueeze(-1) ==
torch.arange(self.num_species, device=self.EtaR.device)) torch.arange(self.num_species, device=self.EtaR.device))
return mask_r return mask_r
def _compute_mask_a(self, species, indices_a, present_species): def _compute_mask_a(self, species_a, present_species):
"""Get mask of angular terms for each supported species from indices""" """Get mask of angular terms for each supported species from indices"""
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self._combinations(species_a, -1) species_a1, species_a2 = self._combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1) mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species) mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
...@@ -283,14 +302,10 @@ class AEVComputer(torch.nn.Module): ...@@ -283,14 +302,10 @@ class AEVComputer(torch.nn.Module):
present_species = utils.present_species(species) present_species = utils.present_species(species)
# TODO: remove this workaround after gather support broadcasting radial_terms, angular_terms, species_ = \
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
radial_terms, angular_terms, indices_r, indices_a = \
self._terms_and_indices(species, coordinates) self._terms_and_indices(species, coordinates)
mask_r = self._compute_mask_r(species_, indices_r) mask_r = self._compute_mask_r(species_)
mask_a = self._compute_mask_a(species_, indices_a, present_species) mask_a = self._compute_mask_a(species_, present_species)
radial, angular = self._assemble(radial_terms, angular_terms, radial, angular = self._assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a) present_species, mask_r, mask_a)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment