Unverified Commit 1ddb371d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Don't compute shifts when no PBC (#449)



* Don't compute shifts when no PBC

* Update aev.py

* Don't pass cell and shifts when not needed (#451)

* Don't pass cell and shifts when not needed

* Update aev.py

* Update aev.py

* Update aev.py
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 1a27b07f
......@@ -168,7 +168,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
return molecule_index + atom_index12, shifts
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]:
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tensor:
"""Compute pairs of atoms that are neighbors (doesn't use PBC)
This function bypasses the calculation of shifts and duplication
......@@ -196,9 +196,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index12 = p12_all[:, pair_index] + molecule_index
# shifts
shifts = coordinates.new_zeros((pair_index.shape[0], 3))
return atom_index12, shifts
return atom_index12
def triu_index(num_species: int) -> Tensor:
......@@ -258,25 +256,29 @@ def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
return central_atom_index, local_index12 % n, sign12
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, triu_index: Tensor,
def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int]) -> Tensor:
sizes: Tuple[int, int, int, int, int], cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes
num_molecules = species.shape[0]
num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength
coordinates_ = coordinates
coordinates = coordinates_.flatten(0, 1)
# PBC calculation is bypassed if there are no shifts
if shifts.numel() == 0:
atom_index12, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr)
if cell_shifts is None:
atom_index12 = neighbor_pairs_nopbc(species == -1, coordinates_, Rcr)
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
vec = selected_coordinates[0] - selected_coordinates[1]
else:
atom_index12, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
coordinates = coordinates.flatten(0, 1)
cell, shifts = cell_shifts
atom_index12, shifts = neighbor_pairs(species == -1, coordinates_, cell, shifts, Rcr)
shift_values = shifts.to(cell.dtype) @ cell
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1))
selected_coordinates = selected_coordinates.view(2, -1, 3)
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
vec = selected_coordinates[0] - selected_coordinates[1] + shift_values
species = species.flatten()
species12 = species[atom_index12]
......@@ -433,12 +435,11 @@ class AEVComputer(torch.nn.Module):
species, coordinates = input_
if cell is None and pbc is None:
cell = self.default_cell
shifts = self.default_shifts
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
else:
assert (cell is not None and pbc is not None)
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, (cell, shifts))
aev = compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
return SpeciesAEV(species, aev)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment