Unverified Commit 65bcbb45 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Enable jit on some functions (#178)

parent e4994ad5
from setuptools import setup, find_packages
import sys
setup_attrs = {
'name': 'torchani',
......@@ -27,4 +28,7 @@ setup_attrs = {
],
}
if sys.version_info[0] < 3:
setup_attrs['install_requires'].append('typing')
setup(**setup_attrs)
......@@ -92,10 +92,10 @@ def time_func(key, func):
# enable timers
nnp[0]._radial_subaev_terms = time_func('radial terms',
nnp[0]._radial_subaev_terms)
nnp[0]._angular_subaev_terms = time_func('angular terms',
nnp[0]._angular_subaev_terms)
torchani.aev._radial_subaev_terms = time_func(
'radial terms', torchani.aev._radial_subaev_terms)
torchani.aev._angular_subaev_terms = time_func(
'angular terms', torchani.aev._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices)
nnp[0]._combinations = time_func('combinations', nnp[0]._combinations)
......
from __future__ import division
import torch
from . import _six # noqa:F401
import math
from . import utils
from torch import Tensor
from typing import Tuple
@torch.jit.script
def _cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
return torch.where(
distances <= cutoff,
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
......@@ -12,7 +17,79 @@ def _cutoff_cosine(distances, cutoff):
)
@torch.jit.script
def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where ``N``
is the number of neighbor atoms within the cutoff radius and output
tensor should have shape
(conformations, atoms, ``self.radial_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-EtaR * (distances - ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
@torch.jit.script
def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius and
output tensor should have shape
(conformations, atoms, ``self.angular_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * \
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, Rca)
fcj2 = _cutoff_cosine(distances2, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA *
((distances1 + distances2) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
@torch.jit.script
def default_neighborlist(species, coordinates, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
"""Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
......@@ -124,70 +201,6 @@ class AEVComputer(torch.nn.Module):
"""Returns the length of full aev"""
return self.radial_length() + self.angular_length()
def _radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where ``N``
is the number of neighbor atoms within the cutoff radius and output
tensor should have shape
(conformations, atoms, ``self.radial_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, self.Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-self.EtaR * (distances - self.ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
def _angular_subaev_terms(self, vectors1, vectors2):
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius and
output tensor should have shape
(conformations, atoms, ``self.angular_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * \
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, self.Rca)
fcj2 = _cutoff_cosine(distances2, self.Rca)
factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta
factor2 = torch.exp(-self.EtaA *
((distances1 + distances2) / 2 - self.ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
def _terms_and_indices(self, species, coordinates):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
......@@ -197,10 +210,12 @@ class AEVComputer(torch.nn.Module):
max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
radial_terms = self._radial_subaev_terms(distances)
radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR,
self.ShfR, distances)
vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec)
angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA,
self.Zeta, self.ShfA, *vec)
return radial_terms, angular_terms, species_
......
......@@ -15,7 +15,7 @@ import ase.units
import copy
class NeighborList:
class NeighborList(torch.nn.Module):
"""ASE neighborlist computer
Arguments:
......@@ -25,11 +25,12 @@ class NeighborList:
def __init__(self, cell=None, pbc=None):
# wrap `cell` and `pbc` with `ase.Atoms`
super(NeighborList, self).__init__()
a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc)
self.pbc = a.get_pbc()
self.cell = a.get_cell(complete=True)
def __call__(self, species, coordinates, cutoff):
def forward(self, species, coordinates, cutoff):
conformations = species.shape[0]
max_atoms = species.shape[1]
neighbor_species = []
......@@ -136,9 +137,9 @@ class Calculator(ase.calculators.calculator.Calculator):
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist:
self.aev_computer.neighborlist = NeighborList(
cell=self.atoms.get_cell(complete=True),
pbc=self.atoms.get_pbc())
self.aev_computer.neighborlist.pbc = self.atoms.get_pbc()
self.aev_computer.neighborlist.cell = \
self.atoms.get_cell(complete=True)
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment