Unverified Commit 74221a98 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use torch.combinations (#116)

parent 391a672c
...@@ -210,19 +210,11 @@ class AEVComputer(torch.nn.Module): ...@@ -210,19 +210,11 @@ class AEVComputer(torch.nn.Module):
return radial_terms, angular_terms, species_ return radial_terms, angular_terms, species_
def _combinations(self, tensor, dim=0): def _combinations(self, tensor, dim=0):
# TODO: remove this when combinations is merged into PyTorch
# https://github.com/pytorch/pytorch/pull/9393
n = tensor.shape[dim] n = tensor.shape[dim]
if n == 0: if n == 0:
return tensor, tensor return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device) r = torch.arange(n, dtype=torch.long, device=tensor.device)
grid_x, grid_y = torch.meshgrid([r, r]) index1, index2 = torch.combinations(r).unbind(-1)
index1 = grid_y.masked_select(
torch.triu(torch.ones(n, n, device=tensor.device),
diagonal=1) == 1)
index2 = grid_x.masked_select(
torch.triu(torch.ones(n, n, device=tensor.device),
diagonal=1) == 1)
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment