Unverified Commit 56eba454 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Code cleanups (#174)

parent 21fcd8f4
...@@ -214,7 +214,7 @@ class AEVComputer(torch.nn.Module): ...@@ -214,7 +214,7 @@ class AEVComputer(torch.nn.Module):
n = tensor.shape[dim] n = tensor.shape[dim]
if n == 0: if n == 0:
return tensor, tensor return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device) r = torch.arange(n, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1) index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \ return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2) tensor.index_select(dim, index2)
...@@ -268,9 +268,8 @@ class AEVComputer(torch.nn.Module): ...@@ -268,9 +268,8 @@ class AEVComputer(torch.nn.Module):
rev_indices = {present_species[i].item(): i rev_indices = {present_species[i].item(): i
for i in range(len(present_species))} for i in range(len(present_species))}
angular_aevs = [] angular_aevs = []
zero_angular_subaev = torch.zeros( zero_angular_subaev = self.EtaR.new_zeros(
conformations, atoms, self.angular_sublength(), conformations, atoms, self.angular_sublength())
dtype=self.EtaR.dtype, device=self.EtaR.device)
for s1, s2 in itertools.combinations_with_replacement( for s1, s2 in itertools.combinations_with_replacement(
range(self.num_species), 2): range(self.num_species), 2):
if s1 in rev_indices and s2 in rev_indices: if s1 in rev_indices and s2 in rev_indices:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment