Unverified Commit 5da2adbc authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Reduce number of kernel calls in AEV (#440)



* Reduce number of kernel calls in AEV

* try

* Revert "try"

This reverts commit 04e56fde671c5168cb7825ab3c33f64a24196d98.

* more

* more

* merge more for angular terms

* neighbor_pairs

* more

* more

* save

* save

* save

* save
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 06cdce78
...@@ -217,7 +217,8 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -217,7 +217,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
for xyz2 in xyz2s: for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -234,7 +235,8 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -234,7 +235,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[i] = 9.9 xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -254,7 +256,8 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -254,7 +256,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[j] = new_j xyz2[j] = new_j
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -269,7 +272,8 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -269,7 +272,8 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double) xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index12, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
atom_index1, atom_index2 = atom_index12.unbind(0)
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
......
...@@ -42,7 +42,7 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T ...@@ -42,7 +42,7 @@ def radial_terms(Rcr: float, EtaR: Tensor, ShfR: Tensor, distances: Tensor) -> T
def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
ShfA: Tensor, vectors1: Tensor, vectors2: Tensor) -> Tensor: ShfA: Tensor, vectors12: Tensor) -> Tensor:
"""Compute the angular subAEV terms of the center atom given neighbor pairs. """Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just This correspond to equation (4) in the `ANI paper`_. This function just
...@@ -55,21 +55,18 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor, ...@@ -55,21 +55,18 @@ def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
vectors1 = vectors1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) vectors12 = vectors12.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) distances12 = vectors12.norm(2, dim=-5)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from # 0.95 is multiplied to the cos values to prevent acos from
# returning NaN. # returning NaN.
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors1, vectors2, dim=-5) cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors12[0], vectors12[1], dim=-5)
angles = torch.acos(cos_angles) angles = torch.acos(cos_angles)
fcj1 = cutoff_cosine(distances1, Rca) fcj12 = cutoff_cosine(distances12, Rca)
fcj2 = cutoff_cosine(distances2, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * ((distances1 + distances2) / 2 - ShfA) ** 2) factor2 = torch.exp(-EtaA * (distances12.sum(0) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2 ret = 2 * factor1 * factor2 * fcj12.prod(0)
# At this point, ret now have shape # At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants. # (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one # We then should flat the last 4 dimensions to view the subAEV as one
...@@ -120,7 +117,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor: ...@@ -120,7 +117,7 @@ def compute_shifts(cell: Tensor, pbc: Tensor, cutoff: float) -> Tensor:
def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors """Compute pairs of atoms that are neighbors
Arguments: Arguments:
...@@ -136,41 +133,42 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, ...@@ -136,41 +133,42 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
coordinates = coordinates.detach() coordinates = coordinates.detach()
cell = cell.detach() cell = cell.detach()
num_atoms = padding_mask.shape[1] num_atoms = padding_mask.shape[1]
num_mols = padding_mask.shape[0]
all_atoms = torch.arange(num_atoms, device=cell.device) all_atoms = torch.arange(num_atoms, device=cell.device)
# Step 2: center cell # Step 2: center cell
# torch.triu_indices is faster than combinations # torch.triu_indices is faster than combinations
p1_center, p2_center = torch.triu_indices(num_atoms, num_atoms, 1, p12_center = torch.triu_indices(num_atoms, num_atoms, 1, device=cell.device)
device=cell.device).unbind(0) shifts_center = shifts.new_zeros((p12_center.shape[1], 3))
shifts_center = shifts.new_zeros((p1_center.shape[0], 3))
# Step 3: cells with shifts # Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3) # shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0] num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device) all_shifts = torch.arange(num_shifts, device=cell.device)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1) prod = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).t()
shift_index = prod[0]
p12 = prod[1:]
shifts_outide = shifts.index_select(0, shift_index) shifts_outide = shifts.index_select(0, shift_index)
# Step 4: combine results for all cells # Step 4: combine results for all cells
shifts_all = torch.cat([shifts_center, shifts_outide]) shifts_all = torch.cat([shifts_center, shifts_outide])
p1_all = torch.cat([p1_center, p1]) p12_all = torch.cat([p12_center, p12], dim=1)
p2_all = torch.cat([p2_center, p2])
shift_values = shifts_all.to(cell.dtype) @ cell shift_values = shifts_all.to(cell.dtype) @ cell
# step 5, compute distances, and find all pairs within cutoff # step 5, compute distances, and find all pairs within cutoff
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1) selected_coordinates = coordinates.index_select(1, p12_all.view(-1)).view(num_mols, 2, -1, 3)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all)) distances = (selected_coordinates[:, 0, ...] - selected_coordinates[:, 1, ...] + shift_values).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all.view(-1)).view(2, -1).any(0)
distances.masked_fill_(padding_mask, math.inf) distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero() in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1) molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms molecule_index *= num_atoms
atom_index1 = p1_all[pair_index] atom_index12 = p12_all[:, pair_index]
atom_index2 = p2_all[pair_index]
shifts = shifts_all.index_select(0, pair_index) shifts = shifts_all.index_select(0, pair_index)
return molecule_index + atom_index1, molecule_index + atom_index2, shifts return molecule_index + atom_index12, shifts
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]: def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors (doesn't use PBC) """Compute pairs of atoms that are neighbors (doesn't use PBC)
This function bypasses the calculation of shifts and duplication This function bypasses the calculation of shifts and duplication
...@@ -180,26 +178,27 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa ...@@ -180,26 +178,27 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa
padding_mask (:class:`torch.Tensor`): boolean tensor of shape padding_mask (:class:`torch.Tensor`): boolean tensor of shape
(molecules, atoms) for padding mask. 1 == is padding. (molecules, atoms) for padding mask. 1 == is padding.
coordinates (:class:`torch.Tensor`): tensor of shape coordinates (:class:`torch.Tensor`): tensor of shape
(molecules, atoms, 3) for atom coordinates. (molecules * atoms, 3) for atom coordinates.
cutoff (float): the cutoff inside which atoms are considered pairs cutoff (float): the cutoff inside which atoms are considered pairs
""" """
coordinates = coordinates.detach() coordinates = coordinates.detach()
current_device = coordinates.device current_device = coordinates.device
num_atoms = padding_mask.shape[1] num_atoms = padding_mask.shape[1]
p1_all, p2_all = torch.triu_indices(num_atoms, num_atoms, 1, num_mols = padding_mask.shape[0]
device=current_device).unbind(0) p12_all = torch.triu_indices(num_atoms, num_atoms, 1, device=current_device)
p12_all_flattened = p12_all.view(-1)
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all)).norm(2, -1) pair_coordinates = coordinates.index_select(1, p12_all_flattened).view(num_mols, 2, -1, 3)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all)) distances = (pair_coordinates[:, 0, ...] - pair_coordinates[:, 1, ...]).norm(2, -1)
padding_mask = padding_mask.index_select(1, p12_all_flattened).view(num_mols, 2, -1).any(dim=1)
distances.masked_fill_(padding_mask, math.inf) distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero() in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1) molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms molecule_index *= num_atoms
atom_index1 = p1_all[pair_index] + molecule_index atom_index12 = p12_all[:, pair_index] + molecule_index
atom_index2 = p2_all[pair_index] + molecule_index
# shifts # shifts
shifts = p1_all.new_zeros((p1_all.shape[0], 3)).index_select(0, pair_index) shifts = coordinates.new_zeros((pair_index.shape[0], 3))
return atom_index1, atom_index2, shifts return atom_index12, shifts
def triu_index(num_species: int) -> Tensor: def triu_index(num_species: int) -> Tensor:
...@@ -217,7 +216,7 @@ def cumsum_from_zero(input_: Tensor) -> Tensor: ...@@ -217,7 +216,7 @@ def cumsum_from_zero(input_: Tensor) -> Tensor:
return cumsum return cumsum
def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]: def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
"""Input: indices for pairs of atoms that are close to each other. """Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists. (2, 1) exists.
...@@ -229,7 +228,7 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor ...@@ -229,7 +228,7 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4) are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
""" """
# convert representation from pair to central-others # convert representation from pair to central-others
ai1 = torch.cat([atom_index1, atom_index2]) ai1 = atom_index12.view(-1)
sorted_ai1, rev_indices = ai1.sort() sorted_ai1, rev_indices = ai1.sort()
# sort and compute unique key # sort and compute unique key
...@@ -245,22 +244,18 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor ...@@ -245,22 +244,18 @@ def triple_by_molecule(atom_index1: Tensor, atom_index2: Tensor) -> Tuple[Tensor
# do local combinations within unique key, assuming sorted # do local combinations within unique key, assuming sorted
m = counts.max().item() if counts.numel() > 0 else 0 m = counts.max().item() if counts.numel() > 0 else 0
n = pair_sizes.shape[0] n = pair_sizes.shape[0]
intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).t().unsqueeze(0).expand(n, -1, -1) intra_pair_indices = torch.tril_indices(m, m, -1, device=ai1.device).unsqueeze(1).expand(-1, n, -1)
mask = (torch.arange(intra_pair_indices.shape[1], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten() mask = (torch.arange(intra_pair_indices.shape[2], device=ai1.device) < pair_sizes.unsqueeze(1)).flatten()
sorted_local_index1, sorted_local_index2 = intra_pair_indices.flatten(0, 1)[mask, :].unbind(-1) sorted_local_index12 = intra_pair_indices.flatten(1, 2)[:, mask]
cumsum = cumsum_from_zero(counts).index_select(0, pair_indices) sorted_local_index12 += cumsum_from_zero(counts).index_select(0, pair_indices)
sorted_local_index1 += cumsum
sorted_local_index2 += cumsum
# unsort result from last part # unsort result from last part
local_index1 = rev_indices[sorted_local_index1] local_index12 = rev_indices[sorted_local_index12]
local_index2 = rev_indices[sorted_local_index2]
# compute mapping between representation of central-other to pair # compute mapping between representation of central-other to pair
n = atom_index1.shape[0] n = atom_index12.shape[1]
sign1 = ((local_index1 < n).to(torch.long) * 2) - 1 sign12 = ((local_index12 < n).to(torch.int8) * 2) - 1
sign2 = ((local_index2 < n).to(torch.long) * 2) - 1 return central_atom_index, local_index12 % n, sign12
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
...@@ -274,45 +269,42 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, ...@@ -274,45 +269,42 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
# PBC calculation is bypassed if there are no shifts # PBC calculation is bypassed if there are no shifts
if shifts.numel() == 0: if shifts.numel() == 0:
atom_index1, atom_index2, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr) atom_index12, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr)
else: else:
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr) atom_index12, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
coordinates = coordinates.flatten(0, 1) coordinates = coordinates.flatten(0, 1)
shift_values = shifts.to(cell.dtype) @ cell shift_values = shifts.to(cell.dtype) @ cell
vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values selected_coordinates = coordinates.index_select(0, atom_index12.view(-1))
selected_coordinates = selected_coordinates.view(2, -1, 3)
vec = selected_coordinates[0] - selected_coordinates[1] + shift_values
species = species.flatten() species = species.flatten()
species1 = species[atom_index1] species12 = species[atom_index12]
species2 = species[atom_index2]
distances = vec.norm(2, -1) distances = vec.norm(2, -1)
# compute radial aev # compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances) radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength)) radial_aev = radial_terms_.new_zeros((num_molecules * num_atoms * num_species, radial_sublength))
index1 = atom_index1 * num_species + species2 index12 = atom_index12 * num_species + species12.flip(0)
index2 = atom_index2 * num_species + species1 radial_aev.index_add_(0, index12[0], radial_terms_)
radial_aev.index_add_(0, index1, radial_terms_) radial_aev.index_add_(0, index12[1], radial_terms_)
radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length) radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources # Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources
# Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca # Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca
even_closer_indices = (distances <= Rca).nonzero().flatten() even_closer_indices = (distances <= Rca).nonzero().flatten()
atom_index1 = atom_index1.index_select(0, even_closer_indices) atom_index12 = atom_index12.index_select(1, even_closer_indices)
atom_index2 = atom_index2.index_select(0, even_closer_indices) species12 = species12.index_select(1, even_closer_indices)
species1 = species1.index_select(0, even_closer_indices)
species2 = species2.index_select(0, even_closer_indices)
vec = vec.index_select(0, even_closer_indices) vec = vec.index_select(0, even_closer_indices)
# compute angular aev # compute angular aev
central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2) central_atom_index, pair_index12, sign12 = triple_by_molecule(atom_index12)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype) species12_small = species12[:, pair_index12]
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype) vec12 = vec.index_select(0, pair_index12.view(-1)).view(2, -1, 3) * sign12.unsqueeze(-1)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1]) species12_ = torch.where(sign12 == 1, species12_small[1], species12_small[0])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2]) angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec12)
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength)) angular_aev = angular_terms_.new_zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength))
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_] index = central_atom_index * num_species_pairs + triu_index[species12_[0], species12_[1]]
angular_aev.index_add_(0, index, angular_terms_) angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length) angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1) return torch.cat([radial_aev, angular_aev], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment