Unverified Commit 220982fb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove dict comprehension and itertools from AEVComputer._assemble (#177)

parent ff1065ba
import torch
import itertools
from . import _six # noqa:F401
import math
from . import utils
......@@ -265,16 +264,19 @@ class AEVComputer(torch.nn.Module):
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
rev_indices = {present_species[i].item(): i
for i in range(len(present_species))}
rev_indices = self.EtaR.new_full((self.num_species,),
-1, dtype=torch.int64)
rev_indices[present_species] = torch.arange(present_species.numel(),
device=self.EtaR.device)
angular_aevs = []
zero_angular_subaev = self.EtaR.new_zeros(
conformations, atoms, self.angular_sublength())
for s1, s2 in itertools.combinations_with_replacement(
range(self.num_species), 2):
if s1 in rev_indices and s2 in rev_indices:
i1 = rev_indices[s1]
i2 = rev_indices[s2]
for s1, s2 in torch.combinations(
torch.arange(self.num_species, device=self.EtaR.device),
2, with_replacement=True):
i1 = rev_indices[s1].item()
i2 = rev_indices[s2].item()
if i1 >= 0 and i2 >= 0:
mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment