Commit ae2497de authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Simplify code using torch.triu_indices (#367)

* Simplify code using torch.triu_indices

* Update aev.py
parent 2ba9ae00
......@@ -168,8 +168,7 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
def triu_index(num_species):
# type: (int) -> torch.Tensor
species = torch.arange(num_species, dtype=torch.long)
species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
species1, species2 = torch.triu_indices(num_species, num_species).unbind(0)
pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long)
ret[species1, species2] = pair_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment