Unverified Commit bacde2e7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove unnecessary molecule index (#206)

parent 1c0f0e76
......@@ -176,8 +176,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -194,8 +193,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -215,8 +213,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......@@ -231,8 +228,7 @@ class TestPBCSeeEachOther(unittest.TestCase):
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
......
......@@ -169,10 +169,11 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index1 = p1_all[pair_index]
atom_index2 = p2_all[pair_index]
shifts = shifts_all.index_select(0, pair_index)
return molecule_index, atom_index1, atom_index2, shifts
return molecule_index + atom_index1, molecule_index + atom_index2, shifts
# torch.jit.script
......@@ -219,7 +220,7 @@ def cumsum_from_zero(input_):
# torch.jit.script
def triple_by_molecule(molecule_index, atom_index1, atom_index2):
def triple_by_molecule(atom_index1, atom_index2):
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
......@@ -230,24 +231,20 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
# convert representation from pair to central-other
n = molecule_index.shape[0]
mi = molecule_index.repeat(2)
# convert representation from pair to central-others
n = atom_index1.shape[0]
ai1 = torch.cat([atom_index1, atom_index2])
# sort and compute unique key
mi_ai1 = torch.stack([mi, ai1], dim=1)
m_ac, rev_indices, counts = torch._unique_dim2_temporary_will_remove_soon(mi_ai1, dim=0, sorted=True, return_inverse=True, return_counts=True)
uniqued_molecule_index, uniqued_central_atom_index = m_ac.unbind(1)
uniqued_central_atom_index, rev_indices, counts = torch._unique2_temporary_will_remove_soon(ai1, sorted=True, return_inverse=True, return_counts=True)
# do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2
total_size = pair_sizes.sum()
molecule_index = torch.repeat_interleave(uniqued_molecule_index, pair_sizes)
central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
sorted_local_pair_index = torch.arange(total_size, device=molecule_index.device) - cumsum
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
......@@ -264,7 +261,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
sign2 = torch.where(local_index2 < n, torch.ones_like(local_index2), -torch.ones_like(local_index2))
pair_index1 = torch.where(local_index1 < n, local_index1, local_index1 - n)
pair_index2 = torch.where(local_index2 < n, local_index2, local_index2 - n)
return molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2
return central_atom_index, pair_index1, pair_index2, sign1, sign2
# torch.jit.script
......@@ -276,32 +273,34 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)
molecule_index, atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species1 = species[molecule_index, atom_index1]
species2 = species[molecule_index, atom_index2]
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species = species.flatten()
coordinates = coordinates.flatten(0, 1)
species1 = species[atom_index1]
species2 = species[atom_index2]
shift_values = torch.mm(shifts.to(cell.dtype), cell)
vec = coordinates[molecule_index, atom_index1, :] - coordinates[molecule_index, atom_index2, :] + shift_values
vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
distances = vec.norm(2, -1)
# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2
index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1
index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_)
radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# compute angular aev
molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(molecule_index, atom_index1, atom_index2)
central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_]
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment