Unverified Commit 6f2a3d5f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Precompute default cells (#202)

parent 35531421
......@@ -268,7 +268,7 @@ def triple_by_molecule(molecule_index, atom_index1, atom_index2):
# torch.jit.script
def compute_aev(species, coordinates, cell, pbc_switch, triu_index, constants, sizes):
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
......@@ -276,7 +276,6 @@ def compute_aev(species, coordinates, cell, pbc_switch, triu_index, constants, s
num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)
shifts = compute_shifts(cell, pbc_switch, cutoff)
molecule_index, atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species1 = species[molecule_index, atom_index1]
species2 = species[molecule_index, atom_index2]
......@@ -366,6 +365,15 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('triu_index', triu_index(num_species))
# Set up default cell and compute default shifts.
# These values are used when cell and pbc switch are not given.
cutoff = max(self.Rcr, self.Rca)
default_cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
default_pbc = torch.zeros(3, dtype=torch.uint8, device=self.EtaR.device)
default_shifts = compute_shifts(default_cell, default_pbc, cutoff)
self.register_buffer('default_cell', default_cell)
self.register_buffer('default_shifts', default_shifts)
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
......@@ -403,9 +411,11 @@ class AEVComputer(torch.nn.Module):
"""
if len(input) == 2:
species, coordinates = input
cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
pbc = torch.zeros(3, dtype=torch.uint8, device=self.EtaR.device)
cell = self.default_cell
shifts = self.default_shifts
else:
assert len(input) == 4
species, coordinates, cell, pbc = input
return species, compute_aev(species, coordinates, cell, pbc, self.triu_index, self.constants(), self.sizes)
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)
return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment