Unverified Commit 276a886d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Optimize sort and repeat_interleave in triple_by_molecule (#207)

parent bacde2e7
......@@ -234,27 +234,28 @@ def triple_by_molecule(atom_index1, atom_index2):
# convert representation from pair to central-others
n = atom_index1.shape[0]
ai1 = torch.cat([atom_index1, atom_index2])
sorted_ai1, rev_indices = ai1.sort()
# sort and compute unique key
uniqued_central_atom_index, rev_indices, counts = torch._unique2_temporary_will_remove_soon(ai1, sorted=True, return_inverse=True, return_counts=True)
uniqued_central_atom_index, _, counts = torch._unique2_temporary_will_remove_soon(sorted_ai1, sorted=True, return_inverse=False, return_counts=True)
# do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2
total_size = pair_sizes.sum()
central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes)
pair_indices = torch.repeat_interleave(pair_sizes)
central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
cumsum = cumsum.index_select(0, pair_indices)
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
cumsum = cumsum.index_select(0, pair_indices)
sorted_local_index1 += cumsum
sorted_local_index2 += cumsum
# unsort result from last part
argsort = rev_indices.argsort()
local_index1 = argsort[sorted_local_index1]
local_index2 = argsort[sorted_local_index2]
local_index1 = rev_indices[sorted_local_index1]
local_index2 = rev_indices[sorted_local_index2]
# compute mapping between representation of central-other to pair
sign1 = torch.where(local_index1 < n, torch.ones_like(local_index1), -torch.ones_like(local_index1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment