Unverified Commit 8c493a6e authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

use 0 size n dim tensor feature to simplify code (#43)

parent 0e992fe5
...@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer): ...@@ -356,8 +356,7 @@ class SortedAEV(AEVComputer):
vec = vec.gather(-2, _indices_a) vec = vec.gather(-2, _indices_a)
# TODO: can we move combinations to ATen? # TODO: can we move combinations to ATen?
vec = self.combinations(vec, -2) vec = self.combinations(vec, -2)
angular_terms = self.angular_subaev_terms( angular_terms = self.angular_subaev_terms(*vec)
*vec) if vec is not None else None
return radial_terms, angular_terms, indices_r, indices_a return radial_terms, angular_terms, indices_r, indices_a
...@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer): ...@@ -367,11 +366,6 @@ class SortedAEV(AEVComputer):
grid_x, grid_y = torch.meshgrid([r, r]) grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y[torch.triu(torch.ones(n, n), diagonal=1) == 1] index1 = grid_y[torch.triu(torch.ones(n, n), diagonal=1) == 1]
index2 = grid_x[torch.triu(torch.ones(n, n), diagonal=1) == 1] index2 = grid_x[torch.triu(torch.ones(n, n), diagonal=1) == 1]
if torch.numel(index1) == 0:
# TODO: pytorch are unable to handle size 0 tensor well.
# Is this an expected behavior?
# See: https://github.com/pytorch/pytorch/issues/5014
return None
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
...@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer): ...@@ -412,21 +406,16 @@ class SortedAEV(AEVComputer):
present species) storing the mask for each pair. present species) storing the mask for each pair.
""" """
species_a = self.combinations(species_a, -1) species_a = self.combinations(species_a, -1)
if species_a is not None: species_a1, species_a2 = species_a
# TODO: can we remove this if pytorch support 0 size tensors?
species_a1, species_a2 = species_a mask_a1 = (species_a1.unsqueeze(-1) ==
present_species).unsqueeze(-1)
if species_a is not None: mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1)
mask_a1 = (species_a1.unsqueeze(-1) == == present_species)
present_species).unsqueeze(-1) mask = mask_a1 * mask_a2
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) mask_rev = mask.permute(0, 1, 2, 4, 3)
== present_species) mask_a = (mask + mask_rev) > 0
mask = mask_a1 * mask_a2 return mask_a
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = (mask + mask_rev) > 0
return mask_a
else:
return None
def assemble(self, radial_terms, angular_terms, present_species, def assemble(self, radial_terms, angular_terms, present_species,
mask_r, mask_a): mask_r, mask_a):
...@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer): ...@@ -480,8 +469,7 @@ class SortedAEV(AEVComputer):
dtype=self.dtype, device=self.device) dtype=self.dtype, device=self.device)
for s1, s2 in itertools.combinations_with_replacement( for s1, s2 in itertools.combinations_with_replacement(
range(len(self.species)), 2): range(len(self.species)), 2):
# TODO: can we remove this if pytorch support 0 size tensors? if s1 in rev_indices and s2 in rev_indices:
if s1 in rev_indices and s2 in rev_indices and mask_a is not None:
i1 = rev_indices[s1] i1 = rev_indices[s1]
i2 = rev_indices[s2] i2 = rev_indices[s2]
mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.dtype) mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment