Unverified Commit 220982fb authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Remove dict comprehension and itertools from AEVComputer._assemble (#177)

parent ff1065ba
import torch import torch
import itertools
from . import _six # noqa:F401 from . import _six # noqa:F401
import math import math
from . import utils from . import utils
...@@ -265,16 +264,19 @@ class AEVComputer(torch.nn.Module): ...@@ -265,16 +264,19 @@ class AEVComputer(torch.nn.Module):
radial_aevs = present_radial_aevs.flatten(start_dim=2) radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev # assemble angular subaev
rev_indices = {present_species[i].item(): i rev_indices = self.EtaR.new_full((self.num_species,),
for i in range(len(present_species))} -1, dtype=torch.int64)
rev_indices[present_species] = torch.arange(present_species.numel(),
device=self.EtaR.device)
angular_aevs = [] angular_aevs = []
zero_angular_subaev = self.EtaR.new_zeros( zero_angular_subaev = self.EtaR.new_zeros(
conformations, atoms, self.angular_sublength()) conformations, atoms, self.angular_sublength())
for s1, s2 in itertools.combinations_with_replacement( for s1, s2 in torch.combinations(
range(self.num_species), 2): torch.arange(self.num_species, device=self.EtaR.device),
if s1 in rev_indices and s2 in rev_indices: 2, with_replacement=True):
i1 = rev_indices[s1] i1 = rev_indices[s1].item()
i2 = rev_indices[s2] i2 = rev_indices[s2].item()
if i1 >= 0 and i2 >= 0:
mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype) mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
subaev = (angular_terms * mask).sum(-2) subaev = (angular_terms * mask).sum(-2)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment