Unverified Commit bacde2e7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove unnecessary molecule index (#206)

parent 1c0f0e76
...@@ -176,8 +176,7 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -176,8 +176,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
for xyz2 in xyz2s: for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -194,8 +193,7 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -194,8 +193,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[i] = 9.9 xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -215,8 +213,7 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -215,8 +213,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[j] = new_i xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
...@@ -231,8 +228,7 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -231,8 +228,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double) xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0) coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1) atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0]) self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1]) self.assertEqual(atom_index2.tolist(), [1])
......
...@@ -169,10 +169,11 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -169,10 +169,11 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
distances.masked_fill_(padding_mask, math.inf) distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero() in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1) molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index1 = p1_all[pair_index] atom_index1 = p1_all[pair_index]
atom_index2 = p2_all[pair_index] atom_index2 = p2_all[pair_index]
shifts = shifts_all.index_select(0, pair_index) shifts = shifts_all.index_select(0, pair_index)
return molecule_index, atom_index1, atom_index2, shifts return molecule_index + atom_index1, molecule_index + atom_index2, shifts
# torch.jit.script # torch.jit.script
...@@ -219,7 +220,7 @@ def cumsum_from_zero(input_): ...@@ -219,7 +220,7 @@ def cumsum_from_zero(input_):
# torch.jit.script # torch.jit.script
def triple_by_molecule(molecule_index, atom_index1, atom_index2): def triple_by_molecule(atom_index1, atom_index2):
"""Input: indices for pairs of atoms that are close to each other. """Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists. (2, 1) exists.
...@@ -230,24 +231,20 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2): ...@@ -230,24 +231,20 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4) are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
""" """
# convert representation from pair to central-other # convert representation from pair to central-others
n = molecule_index.shape[0] n = atom_index1.shape[0]
mi = molecule_index.repeat(2)
ai1 = torch.cat([atom_index1, atom_index2]) ai1 = torch.cat([atom_index1, atom_index2])
# sort and compute unique key # sort and compute unique key
mi_ai1 = torch.stack([mi, ai1], dim=1) uniqued_central_atom_index, rev_indices, counts = torch._unique2_temporary_will_remove_soon(ai1, sorted=True, return_inverse=True, return_counts=True)
m_ac, rev_indices, counts = torch._unique_dim2_temporary_will_remove_soon(mi_ai1, dim=0, sorted=True, return_inverse=True, return_counts=True)
uniqued_molecule_index, uniqued_central_atom_index = m_ac.unbind(1)
# do local combinations within unique key, assuming sorted # do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2 pair_sizes = counts * (counts - 1) // 2
total_size = pair_sizes.sum() total_size = pair_sizes.sum()
molecule_index = torch.repeat_interleave(uniqued_molecule_index, pair_sizes)
central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes) central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes)
cumsum = cumsum_from_zero(pair_sizes) cumsum = cumsum_from_zero(pair_sizes)
cumsum = torch.repeat_interleave(cumsum, pair_sizes) cumsum = torch.repeat_interleave(cumsum, pair_sizes)
sorted_local_pair_index = torch.arange(total_size, device=molecule_index.device) - cumsum sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index) sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts) cumsum = cumsum_from_zero(counts)
cumsum = torch.repeat_interleave(cumsum, pair_sizes) cumsum = torch.repeat_interleave(cumsum, pair_sizes)
...@@ -264,7 +261,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2): ...@@ -264,7 +261,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
sign2 = torch.where(local_index2 < n, torch.ones_like(local_index2), -torch.ones_like(local_index2)) sign2 = torch.where(local_index2 < n, torch.ones_like(local_index2), -torch.ones_like(local_index2))
pair_index1 = torch.where(local_index1 < n, local_index1, local_index1 - n) pair_index1 = torch.where(local_index1 < n, local_index1, local_index1 - n)
pair_index2 = torch.where(local_index2 < n, local_index2, local_index2 - n) pair_index2 = torch.where(local_index2 < n, local_index2, local_index2 - n)
return molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2 return central_atom_index, pair_index1, pair_index2, sign1, sign2
# torch.jit.script # torch.jit.script
...@@ -276,32 +273,34 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -276,32 +273,34 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca) cutoff = max(Rcr, Rca)
molecule_index, atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff) atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species1 = species[molecule_index, atom_index1] species = species.flatten()
species2 = species[molecule_index, atom_index2] coordinates = coordinates.flatten(0, 1)
species1 = species[atom_index1]
species2 = species[atom_index2]
shift_values = torch.mm(shifts.to(cell.dtype), cell) shift_values = torch.mm(shifts.to(cell.dtype), cell)
vec = coordinates[molecule_index, atom_index1, :] - coordinates[molecule_index, atom_index2, :] + shift_values vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
distances = vec.norm(2, -1) distances = vec.norm(2, -1)
# compute radial aev # compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances) radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength) radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2 index1 = atom_index1 * num_species + species2
index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1 index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_) radial_aev.index_add_(0, index1, radial_terms_)
radial_aev.index_add_(0, index2, radial_terms_) radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length) radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# compute angular aev # compute angular aev
molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(molecule_index, atom_index1, atom_index2) central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype) vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype) vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1]) species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2]) species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2) angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength) angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_] index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_) angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length) angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1) return torch.cat([radial_aev, angular_aev], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment