Unverified Commit 55419dfa authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

JIT more (#180)

parent e0411f49
...@@ -92,17 +92,6 @@ def time_func(key, func): ...@@ -92,17 +92,6 @@ def time_func(key, func):
# enable timers # enable timers
torchani.aev._radial_subaev_terms = time_func(
'radial terms', torchani.aev._radial_subaev_terms)
torchani.aev._angular_subaev_terms = time_func(
'angular terms', torchani.aev._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices)
torchani.aev._compute_mask_r = time_func('mask_r',
torchani.aev._compute_mask_r)
torchani.aev._compute_mask_a = time_func('mask_a',
torchani.aev._compute_mask_a)
torchani.aev._assemble = time_func('assemble', torchani.aev._assemble)
nnp[0].forward = time_func('total', nnp[0].forward) nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward) nnp[1].forward = time_func('forward', nnp[1].forward)
...@@ -110,12 +99,6 @@ nnp[1].forward = time_func('forward', nnp[1].forward) ...@@ -110,12 +99,6 @@ nnp[1].forward = time_func('forward', nnp[1].forward)
start = timeit.default_timer() start = timeit.default_timer()
trainer.run(dataset, max_epochs=1) trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
print('Radial terms:', timers['radial terms'])
print('Angular terms:', timers['angular terms'])
print('Terms and indices:', timers['terms and indices'])
print('Mask R:', timers['mask_r'])
print('Mask A:', timers['mask_a'])
print('Assemble:', timers['assemble'])
print('Total AEV:', timers['total']) print('Total AEV:', timers['total'])
print('NN:', timers['forward']) print('NN:', timers['forward'])
print('Epoch time:', elapsed) print('Epoch time:', elapsed)
...@@ -87,6 +87,38 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): ...@@ -87,6 +87,38 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=-4)
@torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
@torch.jit.script
def _terms_and_indices(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA,
distances, vec):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms = _radial_subaev_terms(Rcr, EtaR,
ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(Rca, ShfZ, EtaA,
Zeta, ShfA, *vec)
return radial_terms, angular_terms
@torch.jit.script @torch.jit.script
def default_neighborlist(species, coordinates, cutoff): def default_neighborlist(species, coordinates, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor] # type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
...@@ -121,24 +153,13 @@ def default_neighborlist(species, coordinates, cutoff): ...@@ -121,24 +153,13 @@ def default_neighborlist(species, coordinates, cutoff):
return neighbor_species, neighbor_distances, neighbor_coordinates return neighbor_species, neighbor_distances, neighbor_coordinates
# @torch.jit.script @torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
# @torch.jit.script
def _compute_mask_r(species_r, num_species): def _compute_mask_r(species_r, num_species):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices""" """Get mask of radial terms for each supported species from indices"""
mask_r = (species_r.unsqueeze(-1) == mask_r = (species_r.unsqueeze(-1) ==
torch.arange(num_species, device=species_r.device)) torch.arange(num_species, dtype=torch.long,
device=species_r.device))
return mask_r return mask_r
...@@ -154,7 +175,7 @@ def _compute_mask_a(species_a, present_species): ...@@ -154,7 +175,7 @@ def _compute_mask_a(species_a, present_species):
return mask_a return mask_a
# @torch.jit.script @torch.jit.script
def _assemble(radial_terms, angular_terms, present_species, def _assemble(radial_terms, angular_terms, present_species,
mask_r, mask_a, num_species, angular_sublength): mask_r, mask_a, num_species, angular_sublength):
"""Returns radial and angular AEV computed from terms according """Returns radial and angular AEV computed from terms according
...@@ -172,40 +193,69 @@ def _assemble(radial_terms, angular_terms, present_species, ...@@ -172,40 +193,69 @@ def _assemble(radial_terms, angular_terms, present_species,
mask_a (:class:`torch.Tensor`): shape (conformations, atoms, mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species) pairs, present species, present species)
""" """
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501
conformations = radial_terms.shape[0] conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1] atoms = radial_terms.shape[1]
# assemble radial subaev # assemble radial subaev
present_radial_aevs = ( present_radial_aevs = (
radial_terms.unsqueeze(-2) * radial_terms.unsqueeze(-2) *
mask_r.unsqueeze(-1).type(radial_terms.dtype) mask_r.unsqueeze(-1).to(radial_terms.dtype)
).sum(-3) ).sum(-3)
# present_radial_aevs has shape # present_radial_aevs has shape
# (conformations, atoms, present species, radial_length) # (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2) radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev # assemble angular subaev
rev_indices = present_species.new_full((num_species,), -1) rev_indices = torch.full((num_species,), -1, dtype=present_species.dtype,
device=present_species.device)
rev_indices[present_species] = torch.arange(present_species.numel(), rev_indices[present_species] = torch.arange(present_species.numel(),
dtype=torch.long,
device=radial_terms.device) device=radial_terms.device)
angular_aevs = [] angular_aevs = []
zero_angular_subaev = radial_terms.new_zeros( zero_angular_subaev = torch.zeros(conformations, atoms, angular_sublength,
conformations, atoms, angular_sublength) dtype=radial_terms.dtype,
for s1, s2 in torch.combinations( device=radial_terms.device)
torch.arange(num_species, device=radial_terms.device), for s1 in range(num_species):
2, with_replacement=True): # TODO: make PyTorch support range(start, end) and
i1 = rev_indices[s1].item() # range(start, end, step) and remove the workaround
i2 = rev_indices[s2].item() # below. The inner for loop should be:
if i1 >= 0 and i2 >= 0: # for s2 in range(s1, num_species):
mask = mask_a[..., i1, i2].unsqueeze(-1).type(radial_terms.dtype) for s2 in range(num_species - s1):
subaev = (angular_terms * mask).sum(-2) s2 += s1
else: i1 = int(rev_indices[s1])
subaev = zero_angular_subaev i2 = int(rev_indices[s2])
angular_aevs.append(subaev) if i1 >= 0 and i2 >= 0:
mask = mask_a[:, :, :, i1, i2].unsqueeze(-1) \
.to(radial_terms.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
subaev = zero_angular_subaev
angular_aevs.append(subaev)
return radial_aevs, torch.cat(angular_aevs, dim=2) return radial_aevs, torch.cat(angular_aevs, dim=2)
@torch.jit.script
def _compute_aev(num_species, angular_sublength, Rcr, EtaR, ShfR, Rca, ShfZ,
EtaA, Zeta, ShfA, species, species_, distances, vec):
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
present_species = utils.present_species(species)
radial_terms, angular_terms = _terms_and_indices(
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA, distances, vec)
mask_r = _compute_mask_r(species_, num_species)
mask_a = _compute_mask_a(species_, present_species)
radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a,
num_species, angular_sublength)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
r"""The AEV computer that takes coordinates as input and outputs aevs. r"""The AEV computer that takes coordinates as input and outputs aevs.
...@@ -245,6 +295,9 @@ class AEVComputer(torch.nn.Module): ...@@ -245,6 +295,9 @@ class AEVComputer(torch.nn.Module):
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
__constants__ = ['Rcr', 'Rca', 'num_species', 'radial_sublength',
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species, neighborlist_computer=default_neighborlist): num_species, neighborlist_computer=default_neighborlist):
...@@ -277,24 +330,7 @@ class AEVComputer(torch.nn.Module): ...@@ -277,24 +330,7 @@ class AEVComputer(torch.nn.Module):
# The length of full aev # The length of full aev
self.aev_length = self.radial_length + self.angular_length self.aev_length = self.radial_length + self.angular_length
def _terms_and_indices(self, species, coordinates): # @torch.jit.script_method
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR,
self.ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA,
self.Zeta, self.ShfA, *vec)
return radial_terms, angular_terms, species_
def forward(self, species_coordinates): def forward(self, species_coordinates):
"""Compute AEVs """Compute AEVs
...@@ -309,17 +345,13 @@ class AEVComputer(torch.nn.Module): ...@@ -309,17 +345,13 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())`` ``(C, A, self.aev_length())``
""" """
species, coordinates = species_coordinates # type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
present_species = utils.present_species(species) species, coordinates = species_coordinates
max_cutoff = max(self.Rcr, self.Rca)
radial_terms, angular_terms, species_ = \ species_, distances, vec = self.neighborlist(species, coordinates,
self._terms_and_indices(species, coordinates) max_cutoff)
mask_r = _compute_mask_r(species_, self.num_species) return _compute_aev(
mask_a = _compute_mask_a(species_, present_species) self.num_species, self.angular_sublength, self.Rcr, self.EtaR,
self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA,
radial, angular = _assemble(radial_terms, angular_terms, species, species_, distances, vec)
present_species, mask_r, mask_a,
self.num_species, self.angular_sublength)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
...@@ -30,7 +30,7 @@ class Container(torch.nn.ModuleDict): ...@@ -30,7 +30,7 @@ class Container(torch.nn.ModuleDict):
results = {k: [] for k in self} results = {k: [] for k in self}
for sx in species_x: for sx in species_x:
for k in self: for k in self:
_, result = self[k](sx) _, result = self[k](tuple(sx))
results[k].append(result) results[k].append(result)
for k in self: for k in self:
results[k] = torch.cat(results[k]) results[k] = torch.cat(results[k])
......
...@@ -65,6 +65,7 @@ def pad_coordinates(species_coordinates): ...@@ -65,6 +65,7 @@ def pad_coordinates(species_coordinates):
return torch.cat(species), torch.cat(coordinates) return torch.cat(species), torch.cat(coordinates)
@torch.jit.script
def present_species(species): def present_species(species):
"""Given a vector of species of atoms, compute the unique species present. """Given a vector of species of atoms, compute the unique species present.
...@@ -74,8 +75,8 @@ def present_species(species): ...@@ -74,8 +75,8 @@ def present_species(species):
Returns: Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted. :class:`torch.Tensor`: 1D vector storing present atom types sorted.
""" """
present_species = species.flatten().unique(sorted=True) present_species, _ = species.flatten()._unique(sorted=True)
if present_species[0].item() == -1: if int(present_species[0]) == -1:
present_species = present_species[1:] present_species = present_species[1:]
return present_species return present_species
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment