"src/vscode:/vscode.git/clone" did not exist on "a73f8b725105b12a60a9b22918bda68f8b6d26c3"
Commit 920666fe authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Accelerate angular AEV computation and reduce memory cost (#290)

* Accerate angular AEV computation and reduce memory cost

* reduce number of elementwise product
parent 560d37ac
...@@ -9,11 +9,8 @@ from typing import Tuple ...@@ -9,11 +9,8 @@ from typing import Tuple
# @torch.jit.script # @torch.jit.script
def cutoff_cosine(distances, cutoff): def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
return torch.where( # assuming all elements in distances are smaller than cutoff
distances <= cutoff, return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
torch.zeros_like(distances)
)
# @torch.jit.script # @torch.jit.script
...@@ -270,9 +267,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -270,9 +267,8 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
num_molecules = species.shape[0] num_molecules = species.shape[0]
num_atoms = species.shape[1] num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)
atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff) atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, Rcr)
species = species.flatten() species = species.flatten()
coordinates = coordinates.flatten(0, 1) coordinates = coordinates.flatten(0, 1)
species1 = species[atom_index1] species1 = species[atom_index1]
...@@ -291,6 +287,15 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -291,6 +287,15 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
radial_aev.index_add_(0, index2, radial_terms_) radial_aev.index_add_(0, index2, radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length) radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# Rca is usually much smaller than Rcr, using neighbor list with cutoff=Rcr is a waste of resources
# Now we will get a smaller neighbor list that only cares about atoms with distances <= Rca
even_closer_indices = (distances <= Rca).nonzero().flatten()
atom_index1 = atom_index1.index_select(0, even_closer_indices)
atom_index2 = atom_index2.index_select(0, even_closer_indices)
species1 = species1.index_select(0, even_closer_indices)
species2 = species2.index_select(0, even_closer_indices)
vec = vec.index_select(0, even_closer_indices)
# compute angular aev # compute angular aev
central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2) central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype) vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
...@@ -338,6 +343,7 @@ class AEVComputer(torch.nn.Module): ...@@ -338,6 +343,7 @@ class AEVComputer(torch.nn.Module):
super(AEVComputer, self).__init__() super(AEVComputer, self).__init__()
self.Rcr = Rcr self.Rcr = Rcr
self.Rca = Rca self.Rca = Rca
assert Rca <= Rcr, "Current implementation of AEVComputer assumes Rca <= Rcr"
# convert constant tensors to a ready-to-broadcast shape # convert constant tensors to a ready-to-broadcast shape
# shape convension (..., EtaR, ShfR) # shape convension (..., EtaR, ShfR)
self.register_buffer('EtaR', EtaR.view(-1, 1)) self.register_buffer('EtaR', EtaR.view(-1, 1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment