Unverified Commit 55419dfa authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

JIT more (#180)

parent e0411f49
......@@ -92,17 +92,6 @@ def time_func(key, func):
# enable timers
torchani.aev._radial_subaev_terms = time_func(
'radial terms', torchani.aev._radial_subaev_terms)
torchani.aev._angular_subaev_terms = time_func(
'angular terms', torchani.aev._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices)
torchani.aev._compute_mask_r = time_func('mask_r',
torchani.aev._compute_mask_r)
torchani.aev._compute_mask_a = time_func('mask_a',
torchani.aev._compute_mask_a)
torchani.aev._assemble = time_func('assemble', torchani.aev._assemble)
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)
......@@ -110,12 +99,6 @@ nnp[1].forward = time_func('forward', nnp[1].forward)
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', timers['radial terms'])
print('Angular terms:', timers['angular terms'])
print('Terms and indices:', timers['terms and indices'])
print('Mask R:', timers['mask_r'])
print('Mask A:', timers['mask_a'])
print('Assemble:', timers['assemble'])
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
......@@ -87,6 +87,38 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)
@torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
@torch.jit.script
def _terms_and_indices(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA,
distances, vec):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms = _radial_subaev_terms(Rcr, EtaR,
ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(Rca, ShfZ, EtaA,
Zeta, ShfA, *vec)
return radial_terms, angular_terms
@torch.jit.script
def default_neighborlist(species, coordinates, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
......@@ -121,24 +153,13 @@ def default_neighborlist(species, coordinates, cutoff):
return neighbor_species, neighbor_distances, neighbor_coordinates
# @torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
# @torch.jit.script
@torch.jit.script
def _compute_mask_r(species_r, num_species):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r = (species_r.unsqueeze(-1) ==
torch.arange(num_species, device=species_r.device))
torch.arange(num_species, dtype=torch.long,
device=species_r.device))
return mask_r
......@@ -154,7 +175,7 @@ def _compute_mask_a(species_a, present_species):
return mask_a
# @torch.jit.script
@torch.jit.script
def _assemble(radial_terms, angular_terms, present_species,
mask_r, mask_a, num_species, angular_sublength):
"""Returns radial and angular AEV computed from terms according
......@@ -172,32 +193,42 @@ def _assemble(radial_terms, angular_terms, present_species,
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501
conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1]
# assemble radial subaev
present_radial_aevs = (
radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype)
mask_r.unsqueeze(-1).to(radial_terms.dtype)
).sum(-3)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
rev_indices = present_species.new_full((num_species,), -1)
rev_indices = torch.full((num_species,), -1, dtype=present_species.dtype,
device=present_species.device)
rev_indices[present_species] = torch.arange(present_species.numel(),
dtype=torch.long,
device=radial_terms.device)
angular_aevs = []
zero_angular_subaev = radial_terms.new_zeros(
conformations, atoms, angular_sublength)
for s1, s2 in torch.combinations(
torch.arange(num_species, device=radial_terms.device),
2, with_replacement=True):
i1 = rev_indices[s1].item()
i2 = rev_indices[s2].item()
zero_angular_subaev = torch.zeros(conformations, atoms, angular_sublength,
dtype=radial_terms.dtype,
device=radial_terms.device)
for s1 in range(num_species):
# TODO: make PyTorch support range(start, end) and
# range(start, end, step) and remove the workaround
# below. The inner for loop should be:
# for s2 in range(s1, num_species):
for s2 in range(num_species - s1):
s2 += s1
i1 = int(rev_indices[s1])
i2 = int(rev_indices[s2])
if i1 >= 0 and i2 >= 0:
mask = mask_a[..., i1, i2].unsqueeze(-1).type(radial_terms.dtype)
mask = mask_a[:, :, :, i1, i2].unsqueeze(-1) \
.to(radial_terms.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
subaev = zero_angular_subaev
......@@ -206,6 +237,25 @@ def _assemble(radial_terms, angular_terms, present_species,
return radial_aevs, torch.cat(angular_aevs, dim=2)
@torch.jit.script
def _compute_aev(num_species, angular_sublength, Rcr, EtaR, ShfR, Rca, ShfZ,
EtaA, Zeta, ShfA, species, species_, distances, vec):
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
present_species = utils.present_species(species)
radial_terms, angular_terms = _terms_and_indices(
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA, distances, vec)
mask_r = _compute_mask_r(species_, num_species)
mask_a = _compute_mask_a(species_, present_species)
radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a,
num_species, angular_sublength)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs.
......@@ -245,6 +295,9 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
__constants__ = ['Rcr', 'Rca', 'num_species', 'radial_sublength',
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species, neighborlist_computer=default_neighborlist):
......@@ -277,24 +330,7 @@ class AEVComputer(torch.nn.Module):
# The length of full aev
self.aev_length = self.radial_length + self.angular_length
def _terms_and_indices(self, species, coordinates):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR,
self.ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA,
self.Zeta, self.ShfA, *vec)
return radial_terms, angular_terms, species_
# @torch.jit.script_method
def forward(self, species_coordinates):
"""Compute AEVs
......@@ -309,17 +345,13 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
species, coordinates = species_coordinates
present_species = utils.present_species(species)
radial_terms, angular_terms, species_ = \
self._terms_and_indices(species, coordinates)
mask_r = _compute_mask_r(species_, self.num_species)
mask_a = _compute_mask_a(species_, present_species)
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a,
self.num_species, self.angular_sublength)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
species, coordinates = species_coordinates
max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
return _compute_aev(
self.num_species, self.angular_sublength, self.Rcr, self.EtaR,
self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA,
species, species_, distances, vec)
......@@ -30,7 +30,7 @@ class Container(torch.nn.ModuleDict):
results = {k: [] for k in self}
for sx in species_x:
for k in self:
_, result = self[k](sx)
_, result = self[k](tuple(sx))
results[k].append(result)
for k in self:
results[k] = torch.cat(results[k])
......
......@@ -65,6 +65,7 @@ def pad_coordinates(species_coordinates):
return torch.cat(species), torch.cat(coordinates)
@torch.jit.script
def present_species(species):
"""Given a vector of species of atoms, compute the unique species present.
......@@ -74,8 +75,8 @@ def present_species(species):
Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
"""
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1:
present_species, _ = species.flatten()._unique(sorted=True)
if int(present_species[0]) == -1:
present_species = present_species[1:]
return present_species
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment