Unverified Commit 1ddb371d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Don't compute shifts when no PBC (#449)



* Don't compute shifts when no PBC

* Update aev.py

* Don't pass cell and shifts when not needed (#451)

* Don't pass cell and shifts when not needed

* Update aev.py

* Update aev.py

* Update aev.py
Co-authored-by: default avatarFarhad Ramezanghorbani <farhadrgh@users.noreply.github.com>
parent 1a27b07f
...@@ -168,7 +168,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, ...@@ -168,7 +168,7 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
return molecule_index + atom_index12, shifts return molecule_index + atom_index12, shifts
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tuple[Tensor, Tensor]: def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: float) -> Tensor:
"""Compute pairs of atoms that are neighbors (doesn't use PBC) """Compute pairs of atoms that are neighbors (doesn't use PBC)
This function bypasses the calculation of shifts and duplication This function bypasses the calculation of shifts and duplication
...@@ -196,9 +196,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa ...@@ -196,9 +196,7 @@ def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cutoff: floa
molecule_index, pair_index = in_cutoff.unbind(1) molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms molecule_index *= num_atoms
atom_index12 = p12_all[:, pair_index] + molecule_index atom_index12 = p12_all[:, pair_index] + molecule_index
# shifts return atom_index12
shifts = coordinates.new_zeros((pair_index.shape[0], 3))
return atom_index12, shifts
def triu_index(num_species: int) -> Tensor: def triu_index(num_species: int) -> Tensor:
...@@ -258,25 +256,29 @@ def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]: ...@@ -258,25 +256,29 @@ def triple_by_molecule(atom_index12: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
return central_atom_index, local_index12 % n, sign12 return central_atom_index, local_index12 % n, sign12
def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, def compute_aev(species: Tensor, coordinates: Tensor, triu_index: Tensor,
shifts: Tensor, triu_index: Tensor,
constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor], constants: Tuple[float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor],
sizes: Tuple[int, int, int, int, int]) -> Tensor: sizes: Tuple[int, int, int, int, int], cell_shifts: Optional[Tuple[Tensor, Tensor]]) -> Tensor:
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes num_species, radial_sublength, radial_length, angular_sublength, angular_length = sizes
num_molecules = species.shape[0] num_molecules = species.shape[0]
num_atoms = species.shape[1] num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
coordinates_ = coordinates
coordinates = coordinates_.flatten(0, 1)
# PBC calculation is bypassed if there are no shifts # PBC calculation is bypassed if there are no shifts
if shifts.numel() == 0: if cell_shifts is None:
atom_index12, shifts = neighbor_pairs_nopbc(species == -1, coordinates, Rcr) atom_index12 = neighbor_pairs_nopbc(species == -1, coordinates_, Rcr)
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
vec = selected_coordinates[0] - selected_coordinates[1]
else: else:
atom_index12, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr) cell, shifts = cell_shifts
coordinates = coordinates.flatten(0, 1) atom_index12, shifts = neighbor_pairs(species == -1, coordinates_, cell, shifts, Rcr)
shift_values = shifts.to(cell.dtype) @ cell shift_values = shifts.to(cell.dtype) @ cell
selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)) selected_coordinates = coordinates.index_select(0, atom_index12.view(-1)).view(2, -1, 3)
selected_coordinates = selected_coordinates.view(2, -1, 3) vec = selected_coordinates[0] - selected_coordinates[1] + shift_values
vec = selected_coordinates[0] - selected_coordinates[1] + shift_values
species = species.flatten() species = species.flatten()
species12 = species[atom_index12] species12 = species[atom_index12]
...@@ -433,12 +435,11 @@ class AEVComputer(torch.nn.Module): ...@@ -433,12 +435,11 @@ class AEVComputer(torch.nn.Module):
species, coordinates = input_ species, coordinates = input_
if cell is None and pbc is None: if cell is None and pbc is None:
cell = self.default_cell aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, None)
shifts = self.default_shifts
else: else:
assert (cell is not None and pbc is not None) assert (cell is not None and pbc is not None)
cutoff = max(self.Rcr, self.Rca) cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff) shifts = compute_shifts(cell, pbc, cutoff)
aev = compute_aev(species, coordinates, self.triu_index, self.constants(), self.sizes, (cell, shifts))
aev = compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
return SpeciesAEV(species, aev) return SpeciesAEV(species, aev)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment