Unverified Commit 65bcbb45 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Enable jit on some functions (#178)

parent e4994ad5
from setuptools import setup, find_packages from setuptools import setup, find_packages
import sys
setup_attrs = { setup_attrs = {
'name': 'torchani', 'name': 'torchani',
...@@ -27,4 +28,7 @@ setup_attrs = { ...@@ -27,4 +28,7 @@ setup_attrs = {
], ],
} }
if sys.version_info[0] < 3:
setup_attrs['install_requires'].append('typing')
setup(**setup_attrs) setup(**setup_attrs)
...@@ -92,10 +92,10 @@ def time_func(key, func): ...@@ -92,10 +92,10 @@ def time_func(key, func):
# enable timers # enable timers
nnp[0]._radial_subaev_terms = time_func('radial terms', torchani.aev._radial_subaev_terms = time_func(
nnp[0]._radial_subaev_terms) 'radial terms', torchani.aev._radial_subaev_terms)
nnp[0]._angular_subaev_terms = time_func('angular terms', torchani.aev._angular_subaev_terms = time_func(
nnp[0]._angular_subaev_terms) 'angular terms', torchani.aev._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices', nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices) nnp[0]._terms_and_indices)
nnp[0]._combinations = time_func('combinations', nnp[0]._combinations) nnp[0]._combinations = time_func('combinations', nnp[0]._combinations)
......
from __future__ import division
import torch import torch
from . import _six # noqa:F401 from . import _six # noqa:F401
import math import math
from . import utils from . import utils
from torch import Tensor
from typing import Tuple
@torch.jit.script
def _cutoff_cosine(distances, cutoff): def _cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
return torch.where( return torch.where(
distances <= cutoff, distances <= cutoff,
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5, 0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
...@@ -12,7 +17,79 @@ def _cutoff_cosine(distances, cutoff): ...@@ -12,7 +17,79 @@ def _cutoff_cosine(distances, cutoff):
) )
@torch.jit.script
def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where ``N``
is the number of neighbor atoms within the cutoff radius and output
tensor should have shape
(conformations, atoms, ``self.radial_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-EtaR * (distances - ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
@torch.jit.script
def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius and
output tensor should have shape
(conformations, atoms, ``self.angular_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * \
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, Rca)
fcj2 = _cutoff_cosine(distances2, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA *
((distances1 + distances2) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
@torch.jit.script
def default_neighborlist(species, coordinates, cutoff): def default_neighborlist(species, coordinates, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
"""Default neighborlist computer""" """Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1) vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
...@@ -124,70 +201,6 @@ class AEVComputer(torch.nn.Module): ...@@ -124,70 +201,6 @@ class AEVComputer(torch.nn.Module):
"""Returns the length of full aev""" """Returns the length of full aev"""
return self.radial_length() + self.angular_length() return self.radial_length() + self.angular_length()
def _radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors
This correspond to equation (3) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where ``N``
is the number of neighbor atoms within the cutoff radius and output
tensor should have shape
(conformations, atoms, ``self.radial_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, self.Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-self.EtaR * (distances - self.ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
def _angular_subaev_terms(self, vectors1, vectors2):
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
This correspond to equation (4) in the `ANI paper`_. This function just
compute the terms. The sum in the equation is not computed.
The input tensor have shape (conformations, atoms, N), where N
is the number of neighbor atom pairs within the cutoff radius and
output tensor should have shape
(conformations, atoms, ``self.angular_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * \
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, self.Rca)
fcj2 = _cutoff_cosine(distances2, self.Rca)
factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta
factor2 = torch.exp(-self.EtaA *
((distances1 + distances2) / 2 - self.ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
def _terms_and_indices(self, species, coordinates): def _terms_and_indices(self, species, coordinates):
"""Returns radial and angular subAEV terms, these terms will be sorted """Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within according to their distances to central atoms, and only these within
...@@ -197,10 +210,12 @@ class AEVComputer(torch.nn.Module): ...@@ -197,10 +210,12 @@ class AEVComputer(torch.nn.Module):
max_cutoff = max(self.Rcr, self.Rca) max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates, species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff) max_cutoff)
radial_terms = self._radial_subaev_terms(distances) radial_terms = _radial_subaev_terms(self.Rcr, self.EtaR,
self.ShfR, distances)
vec = self._combinations(vec, -2) vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec) angular_terms = _angular_subaev_terms(self.Rca, self.ShfZ, self.EtaA,
self.Zeta, self.ShfA, *vec)
return radial_terms, angular_terms, species_ return radial_terms, angular_terms, species_
......
...@@ -15,7 +15,7 @@ import ase.units ...@@ -15,7 +15,7 @@ import ase.units
import copy import copy
class NeighborList: class NeighborList(torch.nn.Module):
"""ASE neighborlist computer """ASE neighborlist computer
Arguments: Arguments:
...@@ -25,11 +25,12 @@ class NeighborList: ...@@ -25,11 +25,12 @@ class NeighborList:
def __init__(self, cell=None, pbc=None): def __init__(self, cell=None, pbc=None):
# wrap `cell` and `pbc` with `ase.Atoms` # wrap `cell` and `pbc` with `ase.Atoms`
super(NeighborList, self).__init__()
a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc) a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc)
self.pbc = a.get_pbc() self.pbc = a.get_pbc()
self.cell = a.get_cell(complete=True) self.cell = a.get_cell(complete=True)
def __call__(self, species, coordinates, cutoff): def forward(self, species, coordinates, cutoff):
conformations = species.shape[0] conformations = species.shape[0]
max_atoms = species.shape[1] max_atoms = species.shape[1]
neighbor_species = [] neighbor_species = []
...@@ -136,9 +137,9 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -136,9 +137,9 @@ class Calculator(ase.calculators.calculator.Calculator):
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes) super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist: if not self._default_neighborlist:
self.aev_computer.neighborlist = NeighborList( self.aev_computer.neighborlist.pbc = self.atoms.get_pbc()
cell=self.atoms.get_cell(complete=True), self.aev_computer.neighborlist.cell = \
pbc=self.atoms.get_pbc()) self.atoms.get_cell(complete=True)
species = self.species_to_tensor(self.atoms.get_chemical_symbols()) species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0) species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions()) coordinates = torch.tensor(self.atoms.get_positions())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment