"vscode:/vscode.git/clone" did not exist on "4ad999d1440e896abec3f3c7029f292ce46cc820"
Unverified Commit e0411f49 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Move more computation ouside AEVComputer (#179)

parent 65bcbb45
...@@ -14,7 +14,7 @@ class TestAEV(unittest.TestCase): ...@@ -14,7 +14,7 @@ class TestAEV(unittest.TestCase):
def setUp(self): def setUp(self):
builtins = torchani.neurochem.Builtins() builtins = torchani.neurochem.Builtins()
self.aev_computer = builtins.aev_computer self.aev_computer = builtins.aev_computer
self.radial_length = self.aev_computer.radial_length() self.radial_length = self.aev_computer.radial_length
self.tolerance = 1e-5 self.tolerance = 1e-5
def random_skip(self): def random_skip(self):
......
...@@ -16,7 +16,7 @@ class TestNeuroChem(unittest.TestCase): ...@@ -16,7 +16,7 @@ class TestNeuroChem(unittest.TestCase):
trainer = torchani.neurochem.Trainer(iptpath, d, True, 'runs') trainer = torchani.neurochem.Trainer(iptpath, d, True, 'runs')
# test if loader construct correct model # test if loader construct correct model
self.assertEqual(trainer.aev_computer.aev_length(), 384) self.assertEqual(trainer.aev_computer.aev_length, 384)
m = trainer.model m = trainer.model
H, C, N, O = m # noqa: E741 H, C, N, O = m # noqa: E741
self.assertIsInstance(H[0], torch.nn.Linear) self.assertIsInstance(H[0], torch.nn.Linear)
......
...@@ -98,10 +98,11 @@ torchani.aev._angular_subaev_terms = time_func( ...@@ -98,10 +98,11 @@ torchani.aev._angular_subaev_terms = time_func(
'angular terms', torchani.aev._angular_subaev_terms) 'angular terms', torchani.aev._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices', nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices) nnp[0]._terms_and_indices)
nnp[0]._combinations = time_func('combinations', nnp[0]._combinations) torchani.aev._compute_mask_r = time_func('mask_r',
nnp[0]._compute_mask_r = time_func('mask_r', nnp[0]._compute_mask_r) torchani.aev._compute_mask_r)
nnp[0]._compute_mask_a = time_func('mask_a', nnp[0]._compute_mask_a) torchani.aev._compute_mask_a = time_func('mask_a',
nnp[0]._assemble = time_func('assemble', nnp[0]._assemble) torchani.aev._compute_mask_a)
torchani.aev._assemble = time_func('assemble', torchani.aev._assemble)
nnp[0].forward = time_func('total', nnp[0].forward) nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward) nnp[1].forward = time_func('forward', nnp[1].forward)
...@@ -112,7 +113,6 @@ elapsed = round(timeit.default_timer() - start, 2) ...@@ -112,7 +113,6 @@ elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', timers['radial terms']) print('Radial terms:', timers['radial terms'])
print('Angular terms:', timers['angular terms']) print('Angular terms:', timers['angular terms'])
print('Terms and indices:', timers['terms and indices']) print('Terms and indices:', timers['terms and indices'])
print('Combinations:', timers['combinations'])
print('Mask R:', timers['mask_r']) print('Mask R:', timers['mask_r'])
print('Mask A:', timers['mask_a']) print('Mask A:', timers['mask_a'])
print('Assemble:', timers['assemble']) print('Assemble:', timers['assemble'])
......
...@@ -121,6 +121,91 @@ def default_neighborlist(species, coordinates, cutoff): ...@@ -121,6 +121,91 @@ def default_neighborlist(species, coordinates, cutoff):
return neighbor_species, neighbor_distances, neighbor_coordinates return neighbor_species, neighbor_distances, neighbor_coordinates
# @torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
# @torch.jit.script
def _compute_mask_r(species_r, num_species):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(num_species, device=species_r.device))
return mask_r
@torch.jit.script
def _compute_mask_a(species_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a1, species_a2 = _combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 & mask_a2
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = mask | mask_rev
return mask_a
# @torch.jit.script
def _assemble(radial_terms, angular_terms, present_species,
mask_r, mask_a, num_species, angular_sublength):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1]
# assemble radial subaev
present_radial_aevs = (
radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype)
).sum(-3)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
rev_indices = present_species.new_full((num_species,), -1)
rev_indices[present_species] = torch.arange(present_species.numel(),
device=radial_terms.device)
angular_aevs = []
zero_angular_subaev = radial_terms.new_zeros(
conformations, atoms, angular_sublength)
for s1, s2 in torch.combinations(
torch.arange(num_species, device=radial_terms.device),
2, with_replacement=True):
i1 = rev_indices[s1].item()
i2 = rev_indices[s2].item()
if i1 >= 0 and i2 >= 0:
mask = mask_a[..., i1, i2].unsqueeze(-1).type(radial_terms.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
subaev = zero_angular_subaev
angular_aevs.append(subaev)
return radial_aevs, torch.cat(angular_aevs, dim=2)
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs. r"""The AEV computer that takes coordinates as input and outputs aevs.
...@@ -179,27 +264,18 @@ class AEVComputer(torch.nn.Module): ...@@ -179,27 +264,18 @@ class AEVComputer(torch.nn.Module):
self.num_species = num_species self.num_species = num_species
self.neighborlist = neighborlist_computer self.neighborlist = neighborlist_computer
def radial_sublength(self): # The length of radial subaev of a single species
"""Returns the length of radial subaev of a single species""" self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
return self.EtaR.numel() * self.ShfR.numel() # The length of full radial aev
self.radial_length = self.num_species * self.radial_sublength
def radial_length(self): # The length of angular subaev of a single species
"""Returns the length of full radial aev""" self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * \
return self.num_species * self.radial_sublength() self.ShfA.numel() * self.ShfZ.numel()
# The length of full angular aev
def angular_sublength(self): self.angular_length = (self.num_species * (self.num_species + 1)) \
"""Returns the length of angular subaev of a single species""" // 2 * self.angular_sublength
return self.EtaA.numel() * self.Zeta.numel() * self.ShfA.numel() * \ # The length of full aev
self.ShfZ.numel() self.aev_length = self.radial_length + self.angular_length
def angular_length(self):
"""Returns the length of full angular aev"""
s = self.num_species
return (s * (s + 1)) // 2 * self.angular_sublength()
def aev_length(self):
"""Returns the length of full aev"""
return self.radial_length() + self.angular_length()
def _terms_and_indices(self, species, coordinates): def _terms_and_indices(self, species, coordinates):
"""Returns radial and angular subAEV terms, these terms will be sorted """Returns radial and angular subAEV terms, these terms will be sorted
...@@ -213,88 +289,12 @@ class AEVComputer(torch.nn.Module): ...@@ -213,88 +289,12 @@ class AEVComputer(torch.nn.Module):
radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR, radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR,
self.ShfR, distances) self.ShfR, distances)
vec = self._combinations(vec, -2) vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA, angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA,
self.Zeta, self.ShfA, *vec) self.Zeta, self.ShfA, *vec)
return radial_terms, angular_terms, species_ return radial_terms, angular_terms, species_
def _combinations(self, tensor, dim=0):
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
def _compute_mask_r(self, species_r):
"""Get mask of radial terms for each supported species from indices"""
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(self.num_species, device=self.EtaR.device))
return mask_r
def _compute_mask_a(self, species_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a1, species_a2 = self._combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 & mask_a2
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = mask | mask_rev
return mask_a
def _assemble(self, radial_terms, angular_terms, present_species,
mask_r, mask_a):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1]
# assemble radial subaev
present_radial_aevs = (
radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype)
).sum(-3)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
rev_indices = self.EtaR.new_full((self.num_species,),
-1, dtype=torch.int64)
rev_indices[present_species] = torch.arange(present_species.numel(),
device=self.EtaR.device)
angular_aevs = []
zero_angular_subaev = self.EtaR.new_zeros(
conformations, atoms, self.angular_sublength())
for s1, s2 in torch.combinations(
torch.arange(self.num_species, device=self.EtaR.device),
2, with_replacement=True):
i1 = rev_indices[s1].item()
i2 = rev_indices[s2].item()
if i1 >= 0 and i2 >= 0:
mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
subaev = zero_angular_subaev
angular_aevs.append(subaev)
return radial_aevs, torch.cat(angular_aevs, dim=2)
def forward(self, species_coordinates): def forward(self, species_coordinates):
"""Compute AEVs """Compute AEVs
...@@ -315,10 +315,11 @@ class AEVComputer(torch.nn.Module): ...@@ -315,10 +315,11 @@ class AEVComputer(torch.nn.Module):
radial_terms, angular_terms, species_ = \ radial_terms, angular_terms, species_ = \
self._terms_and_indices(species, coordinates) self._terms_and_indices(species, coordinates)
mask_r = self._compute_mask_r(species_) mask_r = _compute_mask_r(species_, self.num_species)
mask_a = self._compute_mask_a(species_, present_species) mask_a = _compute_mask_a(species_, present_species)
radial, angular = self._assemble(radial_terms, angular_terms, radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a) present_species, mask_r, mask_a,
self.num_species, self.angular_sublength)
fullaev = torch.cat([radial, angular], dim=2) fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev return species, fullaev
...@@ -588,7 +588,7 @@ if sys.version_info[0] > 2: ...@@ -588,7 +588,7 @@ if sys.version_info[0] > 2:
# construct networks # construct networks
input_size, network_setup = network_setup input_size, network_setup = network_setup
if input_size != self.aev_computer.aev_length(): if input_size != self.aev_computer.aev_length:
raise ValueError('AEV size and input size does not match') raise ValueError('AEV size and input size does not match')
l2reg = [] l2reg = []
atomic_nets = {} atomic_nets = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment