Unverified Commit 5da2adbc authored by Gao, Xiang, committed by GitHub
Browse files

Reduce number of kernel calls in AEV (#440)



* Reduce number of kernel calls in AEV

* try

* Revert "try"

This reverts commit 04e56fde671c5168cb7825ab3c33f64a24196d98.

* more

* more

* merge more for angular terms

* neighbor_pairs

* more

* more

* save

* save

* save

* save
Co-authored-by: Farhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 06cdce78
......@@ -217,7 +217,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -234,7 +235,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -254,7 +256,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[j] = new_j
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -269,7 +272,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......
......@@ -42,7 +42,7 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T
def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
ShfA: Tensor, vectors1: Tensor, vectors2: Tensor) -> Tensor:
ShfA: Tensor, vectors12: Tensor) -> Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This corresponds to equation (4) in the `ANI paper`_. This function just
......@@ -55,21 +55,18 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances12 = vectors12.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors1, vectors2, dim=-5)
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors12[0], vectors12[1], dim=-5)
angles = torch.acos(cos_angles)
fcj1 = cutoff_cosine(distances1, Rca)
fcj2 = cutoff_cosine(distances2, Rca)
fcj12 = cutoff_cosine(distances12, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * ((distances1 + distances2) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
factor2 = torch.exp(-EtaA * (distances12.sum(0) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now has shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flatten the last 4 dimensions to view the subAEV as one
......@@ -120,7 +117,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors
Arguments:
......@@ -136,41 +133,42 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
coordinates = coordinates.detach()
cell = cell.detach()
num_atoms = padding_mask.shape[1]
num_mols = padding_mask.shape[0]
all_atoms = torch.arange(num_atoms, device=cell.device)
# Step 2: center cell
# torch.triu_indices is faster than combinations
p1_center, p2_center = torch.triu_indices(num_atoms, num_atoms, 1,
device=cell.device).unbind(0)
shifts_center = shifts.new_zeros((p1_center.shape[0], 3))
p12_center = torch.triu_indices(num_atoms, num_atoms, 1, device=cell.device)
shifts_center = shifts.new_zeros((p12_center.shape[1], 3))
# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
prod = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).t()
shift_index = prod[0]
p12 = prod[1:]
shifts_outide = shifts.index_select(0, shift_index)
# Step 4: combine results for all cells
shifts_all = torch.cat([shifts_center, shifts_outide])
p1_all = torch.cat([p1_center, p1])
p2_all = torch.cat([p2_center, p2])
p12_all = torch.cat([p12_center, p12], dim=1)
shift_values = shifts_all.to(cell.dtype) @ cell
# step 5, compute distances, and find all pairs within cutoff
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
selected_coordinates = coordinates.index_select(1, p12_all.view(-1)).view(num_mols, 2, -1, 3)
distances = (selected_coordinates[:, 0, ...] - selected_coordinates[:, 1, ...] + shift_values).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all.view(-1)).view(2, -1).any(0)
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index1 = p1_all[pair_index]
atom_index2 = p2_all[pair_index]
atom_index12 = p12_all[:, pair_index]
shifts = shifts_all.index_select(0, pair_index)
return molecule_index + atom_index1, molecule_index + atom_index2, shifts
return molecule_index + atom_index12, shifts
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors (doesn't use PBC)
This function bypasses the calculation of shifts and duplication
......@@ -180,26 +178,27 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa
padding_mask (:class:`torch.Tensor`): boolean tensor of shape
(molecules, atoms) for padding mask. 1 == is padding.
coordinates (:class:`torch.Tensor`): tensor of shape
(molecules, atoms, 3) for atom coordinates.
(molecules * atoms, 3) for atom coordinates.
cutoff (float): the cutoff inside which atoms are considered pairs
"""
coordinates = coordinates.detach()
current_device = coordinates.device
num_atoms = padding_mask.shape[1]
p1_all, p2_all = torch.triu_indices(num_atoms, num_atoms, 1,
device=current_device).unbind(0)
num_mols = padding_mask.shape[0]
p12_all = torch.triu_indices(num_atoms, num_atoms, 1, device=current_device)
p12_all_flattened = p12_all.view(-1)
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all)).norm(2, -1)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
pair_coordinates = coordinates.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3)
distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all_flattened).view(num_mols, 2, -1).any(dim=1)
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index1 = p1_all[pair_index] + molecule_index
atom_index2 = p2_all[pair_index] + molecule_index
atom_index12 = p12_all[:, pair_index] + molecule_index
# shifts
shifts = p1_all.new_zeros((p1_all.shape[0], 3)).index_select(0, pair_index)
return atom_index1, atom_index2, shifts
shifts = coordinates.new_zeros((pair_index.shape[0], 3))
return atom_index12, shifts
def triu_index(num_species: int) -> Tensor:
......@@ -217,7 +216,7 @@ def cumsum_from_zero(input_: Tensor) -> Tensor:
return cumsum
def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Input: indices for pairs of atoms that are close to each other.
Each pair appears only once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
......@@ -229,7 +228,7 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
# convert representation from pair to central-others
ai1 = torch.cat([atom_index1, atom_index2])
ai1 = atom_index12.view(-1)
sorted_ai1, rev_indices = ai1.sort()
# sort and compute unique key
......@@ -245,22 +244,18 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
# do local combinations within unique key, assuming sorted
m = counts.max().item() if counts.numel() > 0 else 0
n = pair_sizes.shape[0]
intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).t().unsqueeze(0).expand(n, -1, -1)
mask = (torch.arange(intra_pair_indices.shape[1], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
sorted_local_index1, sorted_local_index2 = intra_pair_indices.flatten(0, 1)[mask, :].unbind(-1)
cumsum = cumsum_from_zero(counts).index_select(0, pair_indices)
sorted_local_index1 += cumsum
sorted_local_index2 += cumsum
intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)
# unsort result from last part
local_index1 = rev_indices[sorted_local_index1]
local_index2 = rev_indices[sorted_local_index2]
local_index12 = rev_indices[sorted_local_index12]
# compute mapping between representation of central-other to pair
n = atom_index1.shape[0]
sign1 = ((local_index1 < n).to(torch.long) * 2) - 1
sign2 = ((local_index2 < n).to(torch.long) * 2) - 1
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
n = atom_index12.shape[1]
sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
return central_atom_index, local_index12 % n, sign12
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
......@@ -274,45 +269,42 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
num_species_pairs = angular_length // angular_sublength
# PBC calculation is bypassed if there are no shifts
if shifts.numel() == 0:
atom_index1, atom_index2, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr)
atom_index12, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr)
else:
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
atom_index12, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
coordinates = coordinates.flatten(0, 1)
shift_values = shifts.to(cell.dtype) @ cell
vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1))
selected_coordinates = selected_coordinates.view(2, -1, 3)
vec = selected_coordinates[0] - selected_coordinates[1] + shift_values
species = species.flatten()
species1 = species[atom_index1]
species2 = species[atom_index2]
species12 = species[atom_index12]
distances = vec.norm(2, -1)
# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_)
radial_aev.index_add_(0, index2, radial_terms_)
index12 = atom_index12 * num_species + species12.flip(0)
radial_aev.index_add_(0, index12[0], radial_terms_)
radial_aev.index_add_(0, index12[1], radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources
# Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca
even_closer_indices = (distances <= Rca).nonzero().flatten()
atom_index1 = atom_index1.index_select(0, even_closer_indices)
atom_index2 = atom_index2.index_select(0, even_closer_indices)
species1 = species1.index_select(0, even_closer_indices)
species2 = species2.index_select(0, even_closer_indices)
atom_index12 = atom_index12.index_select(1, even_closer_indices)
species12 = species12.index_select(1, even_closer_indices)
vec = vec.index_select(0, even_closer_indices)
# compute angular aev
central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
central_atom_index, pair_index12, sign12 = triple_by_molecule(atom_index12)
species12_small = species12[:, pair_index12]
vec12 = vec.index_select(0, pair_index12.view(-1)).view(2, -1, 3) * sign12.unsqueeze(-1)
species12_ = torch.where(sign12 == 1, species12_small[1], species12_small[0])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec12)
angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength))
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
index = central_atom_index * num_species_pairs + triu_index[species12_[0], species12_[1]]
angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment