Unverified Commit 0d0e6d5b authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

bypass slow code (#420)

* bypass slow code

* delete reference to all_atoms which was unused

* delete whitespace

* make flake8 happy

* rerun

* actual change was off

* typo did not turn on bypassing code
parent 21c794a8
...@@ -170,6 +170,39 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor, ...@@ -170,6 +170,39 @@ def neighbor_pairs(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
return molecule_index + atom_index1, molecule_index + atom_index2, shifts return molecule_index + atom_index1, molecule_index + atom_index2, shifts
def neighbor_pairs_nopbc(padding_mask: Tensor, coordinates: Tensor, cell: Tensor,
shifts: Tensor, cutoff: float) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute pairs of atoms that are neighbors (doesn't use PBC)
This function bypasses the calculation of shifts and duplication
of atoms in order to make calculations faster
Arguments:
padding_mask (:class:`torch.Tensor`): boolean tensor of shape
(molecules, atoms) for padding mask. 1 == is padding.
coordinates (:class:`torch.Tensor`): tensor of shape
(molecules, atoms, 3) for atom coordinates.
cutoff (float): the cutoff inside which atoms are considered pairs
"""
coordinates = coordinates.detach()
current_device = coordinates.device
num_atoms = padding_mask.shape[1]
p1_all, p2_all = torch.triu_indices(num_atoms, num_atoms, 1,
device=current_device).unbind(0)
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all)).norm(2, -1)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
molecule_index *= num_atoms
atom_index1 = p1_all[pair_index] + molecule_index
atom_index2 = p2_all[pair_index] + molecule_index
# shifts
shifts = shifts.new_zeros((p1_all.shape[0], 3)).index_select(0, pair_index)
return atom_index1, atom_index2, shifts
def triu_index(num_species: int) -> Tensor: def triu_index(num_species: int) -> Tensor:
species1, species2 = torch.triu_indices(num_species, num_species).unbind(0) species1, species2 = torch.triu_indices(num_species, num_species).unbind(0)
pair_index = torch.arange(species1.shape[0], dtype=torch.long) pair_index = torch.arange(species1.shape[0], dtype=torch.long)
...@@ -240,15 +273,18 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor, ...@@ -240,15 +273,18 @@ def compute_aev(species: Tensor, coordinates: Tensor, cell: Tensor,
num_molecules = species.shape[0] num_molecules = species.shape[0]
num_atoms = species.shape[1] num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
# PBC calculation is bypassed if there are no shifts
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr) if shifts.numel() == 1:
species = species.flatten() atom_index1, atom_index2, shifts = neighbor_pairs_nopbc(species == -1, coordinates, cell, shifts, Rcr)
else:
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
coordinates = coordinates.flatten(0, 1) coordinates = coordinates.flatten(0, 1)
shift_values = shifts.to(cell.dtype) @ cell
vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
species = species.flatten()
species1 = species[atom_index1] species1 = species[atom_index1]
species2 = species[atom_index2] species2 = species[atom_index2]
shift_values = shifts.to(cell.dtype) @ cell
vec = coordinates.index_select(0, atom_index1) - coordinates.index_select(0, atom_index2) + shift_values
distances = vec.norm(2, -1) distances = vec.norm(2, -1)
# compute radial aev # compute radial aev
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment