Unverified Commit f2170e24 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript compatibility for AEVComputer (#303)

* make aev,model compatible with jit

* add type annotation to nn

* flake8 fix

* refactor AEVComputer

* fix doc

* an example with padding

* use Optional type instead of padding

* fix

* fix

* make pbc and cell keyword arguments in test_aev

* fix

* make pbc and cell keyword arguments in ase

* fix

* fix

* fix dtype

* fix

* aev_computer dtype to double

* change test files to have aev_computer with keyword argument

* fix JIT types

* add TestAEVJIT

* fix LGTM alerts

* fix TestAEVJIT

* Update aev.py

workaround for dtype in `torch.arange`

* More arange bugs

* Even more arange

* fix LGTM alert
parent 3957d19c
......@@ -146,7 +146,7 @@ class TestAEV(unittest.TestCase):
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
self.assertAEVEqual(expected_radial, expected_angular, aev, 5e-5)
def testTripeptideMD(self):
......@@ -245,6 +245,12 @@ class TestAEV(unittest.TestCase):
)
class TestAEVJIT(TestAEV):
    """Run the inherited ``TestAEV`` cases against a scripted AEV computer."""

    def setUp(self):
        """Build the parent fixture, then replace the AEV computer with its
        ``torch.jit.script`` compiled version."""
        super().setUp()
        # The inherited test methods only reach the module through
        # self.aev_computer, so rebinding it here is enough to route them
        # through the compiled module.
        self.aev_computer = torch.jit.script(self.aev_computer)
class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self):
self.ani1x = torchani.models.ANI1x()
......@@ -262,11 +268,11 @@ class TestPBCSeeEachOther(unittest.TestCase):
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.bool)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
for _ in range(100):
translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc))
_, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
self.assertTrue(torch.allclose(aev, aev2))
def testPBCConnersSeeEachOther(self):
......@@ -363,7 +369,7 @@ class TestAEVOnBoundary(unittest.TestCase):
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc))
_, self.aev = self.aev_computer((self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)
def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
......@@ -385,7 +391,7 @@ class TestAEVOnBoundary(unittest.TestCase):
self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc))
_, aev = self.aev_computer((self.species, coordinates), cell=self.cell, pbc=self.pbc)
self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev))
......@@ -402,7 +408,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase):
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
_, self.aev = self.aev_computer((self.species, self.coordinates, self.cell, self.pbc))
_, self.aev = self.aev_computer((self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
self.natoms = self.aev.shape[1]
def testRepeat(self):
......@@ -416,7 +422,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase):
self.coordinates + 3 * c1,
], dim=1)
cell2 = torch.stack([4 * c1, c2, c3])
_, aev2 = self.aev_computer((species2, coordinates2, cell2, self.pbc))
_, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
for i in range(3):
aev3 = aev2[:, i * self.natoms: (i + 1) * self.natoms, :]
self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))
......
......@@ -15,10 +15,11 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
    """Create the ANI-1x pieces used by the energy tests.

    Sets ``self.tolerance``, ``self.aev_computer``, ``self.nn`` (network
    followed by the energy shifter, taking ``(species, aev)``) and
    ``self.model`` (AEV computer, network and energy shifter chained).
    """
    self.tolerance = 5e-5
    ani1x = torchani.models.ANI1x()
    self.aev_computer = ani1x.aev_computer
    nnp = ani1x.neural_networks[0]
    shift_energy = ani1x.energy_shifter
    # Kept separate from the AEV computer so that tests can call the AEV
    # computer themselves with the cell/pbc keyword arguments and feed the
    # resulting AEV into this part.
    self.nn = torch.nn.Sequential(nnp, shift_energy)
    self.model = torch.nn.Sequential(self.aev_computer, nnp, shift_energy)
def random_skip(self):
return False
......@@ -56,7 +57,8 @@ class TestEnergies(unittest.TestCase):
coordinates = self.transform(coordinates)
species = self.transform(species)
energies = self.transform(energies)
_, energies_ = self.model((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nn((species, aev))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, tolerance)
......
......@@ -14,8 +14,8 @@ class TestForce(unittest.TestCase):
self.tolerance = 1e-5
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, nnp)
self.nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, self.nnp)
def random_skip(self):
return False
......@@ -82,7 +82,8 @@ class TestForce(unittest.TestCase):
coordinates = self.transform(coordinates)
species = self.transform(species)
forces = self.transform(forces)
_, energies_ = self.model((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nnp((species, aev))
derivative = torch.autograd.grad(energies_.sum(),
coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
......
......@@ -2,17 +2,15 @@ from __future__ import division
import torch
from . import _six # noqa:F401
import math
from typing import Tuple
from typing import Tuple, Optional
# @torch.jit.script
def cutoff_cosine(distances, cutoff):
    # type: (torch.Tensor, float) -> torch.Tensor
    """Evaluate ``0.5 * cos(pi * d / cutoff) + 0.5`` element-wise.

    Arguments:
        distances (:class:`torch.Tensor`): distances to evaluate.
        cutoff (float): the cutoff radius.

    Returns:
        :class:`torch.Tensor`: tensor of the same shape as ``distances``.
    """
    # assuming all elements in distances are smaller than cutoff; nothing is
    # clamped here, so larger distances are not mapped to zero.
    return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
......@@ -40,7 +38,6 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2)
# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
......@@ -77,8 +74,8 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)
# @torch.jit.script
def compute_shifts(cell, pbc, cutoff):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
......@@ -95,14 +92,13 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device, dtype=torch.long)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device, dtype=torch.long)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device, dtype=torch.long)
o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([
torch.cartesian_prod(r1, r2, r3),
......@@ -121,8 +117,8 @@ def compute_shifts(cell, pbc, cutoff):
])
# @torch.jit.script
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""Compute pairs of atoms that are neighbors
Arguments:
......@@ -135,21 +131,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
coordinates = coordinates.detach()
cell = cell.detach()
num_atoms = padding_mask.shape[1]
all_atoms = torch.arange(num_atoms, device=cell.device)
all_atoms = torch.arange(num_atoms, device=cell.device, dtype=torch.long)
# Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = shifts.new_zeros(p1_center.shape[0], 3)
shifts_center = torch.zeros((p1_center.shape[0], 3), dtype=shifts.dtype, device=shifts.device)
# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device)
all_shifts = torch.arange(num_shifts, device=cell.device, dtype=torch.long)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
shifts_outide = shifts.index_select(0, shift_index)
......@@ -172,19 +166,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts
# torch.jit.script
def triu_index(num_species):
    # type: (int) -> torch.Tensor
    """Build a symmetric table mapping a pair of species to a pair index.

    Unordered pairs ``(i, j)`` with ``i <= j`` are numbered in the order in
    which ``torch.combinations(..., with_replacement=True)`` yields them.

    Arguments:
        num_species (int): number of distinct species.

    Returns:
        :class:`torch.Tensor`: long tensor of shape
        ``(num_species, num_species)`` where entries ``[i, j]`` and
        ``[j, i]`` both hold the index of the pair ``(i, j)``.
    """
    # dtype is passed explicitly to every torch.arange call
    species = torch.arange(num_species, dtype=torch.long)
    species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
    pair_index = torch.arange(species1.shape[0], dtype=torch.long)
    ret = torch.zeros(num_species, num_species, dtype=torch.long)
    # write both orientations so lookups do not depend on argument order
    ret[species1, species2] = pair_index
    ret[species2, species1] = pair_index
    return ret
# torch.jit.script
def convert_pair_index(index):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
......@@ -208,15 +202,15 @@ def convert_pair_index(index):
return index - num_elems, n + 1
# torch.jit.script
def cumsum_from_zero(input_):
    # type: (torch.Tensor) -> torch.Tensor
    """Cumulative sum along dim 0, shifted so that it starts at zero.

    Element ``k`` of the result is the sum of ``input_[:k]``.

    Arguments:
        input_ (:class:`torch.Tensor`): tensor to accumulate along dim 0.

    Returns:
        :class:`torch.Tensor`: the shifted cumulative sum, with the dtype
        and device of ``input_``.
    """
    cumsum = torch.cumsum(input_, dim=0)
    # Prepend a single zero and drop the final total. The shift must be
    # applied exactly once; the zero is built with an explicit dtype/device
    # rather than input_.new_tensor.
    cumsum = torch.cat([torch.tensor([0], dtype=input_.dtype, device=input_.device), cumsum[:-1]])
    return cumsum
# torch.jit.script
def triple_by_molecule(atom_index1, atom_index2):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
......@@ -233,16 +227,18 @@ def triple_by_molecule(atom_index1, atom_index2):
sorted_ai1, rev_indices = ai1.sort()
# sort and compute unique key
uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_counts=True)
unique_results = torch.unique_consecutive(sorted_ai1, return_inverse=True, return_counts=True)
uniqued_central_atom_index = unique_results[0]
counts = unique_results[-1]
# do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2
pair_sizes = (counts * (counts - 1) / 2).long()
total_size = pair_sizes.sum()
pair_indices = torch.repeat_interleave(pair_sizes)
central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = cumsum.index_select(0, pair_indices)
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device, dtype=torch.long) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = cumsum.index_select(0, pair_indices)
......@@ -259,8 +255,8 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
# torch.jit.script
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) -> torch.Tensor
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
......@@ -279,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
radial_aev = torch.zeros((num_molecules * num_atoms * num_species, radial_sublength), dtype=radial_terms_.dtype, device=radial_terms_.device)
index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_)
......@@ -302,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
angular_aev = torch.zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength), dtype=angular_terms_.dtype, device=angular_terms_.device)
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
......@@ -380,23 +376,24 @@ class AEVComputer(torch.nn.Module):
def constants(self):
    """Return the AEV constants stored on this object as one tuple.

    Returns:
        tuple: ``(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA)``, read from
        the attributes of the same names, in that order.
    """
    return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
# @torch.jit.script_method
def forward(self, input_):
def forward(self, input_, cell=None, pbc=None):
# type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""Compute AEVs
Arguments:
input_ (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of molecules
in a chunk, and ``A`` is the number of atoms.
then input can be a tuple of two tensors: species, coordinates.
species must have shape ``(C, A)``, coordinates must have shape
``(C, A, 3)`` where ``C`` is the number of molecules in a chunk,
and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc
where species and coordinates are the same as described above, cell
is a tensor of shape (3, 3) of the three vectors defining unit cell:
would be a tuple of two tensors (species, coordinates) and two keyword
arguments `cell=...` , and `pbc=...` where species and coordinates are
the same as described above, cell is a tensor of shape (3, 3) of the
three vectors defining unit cell:
.. code-block:: python
......@@ -412,13 +409,14 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
if len(input_) == 2:
species, coordinates = input_
species, coordinates = input_
if cell is None and pbc is None:
cell = self.default_cell
shifts = self.default_shifts
else:
assert len(input_) == 4
species, coordinates, cell, pbc = input_
assert (cell is not None and pbc is not None)
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)
return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
......@@ -38,7 +38,8 @@ class Calculator(ase.calculators.calculator.Calculator):
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
self.aev_computer = copy.deepcopy(aev_computer)
aev_computer = copy.deepcopy(aev_computer)
self.aev_computer = aev_computer.to(dtype)
self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite
......@@ -46,8 +47,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.device = self.aev_computer.EtaR.device
self.dtype = dtype
self.whole = torch.nn.Sequential(
self.aev_computer,
self.nn = torch.nn.Sequential(
self.model,
self.energy_shifter
).to(dtype)
......@@ -93,9 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, energy = self.whole((species, coordinates, cell, pbc))
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
else:
_, energy = self.whole((species, coordinates))
_, aev = self.aev_computer((species, coordinates))
_, energy = self.nn((species, aev))
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
......
......@@ -28,6 +28,7 @@ shouldn't be used anymore.
"""
import torch
from typing import Tuple
from pkg_resources import resource_filename
from . import neurochem
from .aev import AEVComputer
......@@ -87,6 +88,7 @@ class BuiltinNet(torch.nn.Module):
self.species, self.ensemble_prefix, self.ensemble_size)
def forward(self, species_coordinates):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""Calculates predicted properties for minibatch of configurations
Args:
......
import torch
from . import utils
from typing import Tuple
class ANIModel(torch.nn.ModuleList):
......@@ -34,6 +35,7 @@ class ANIModel(torch.nn.ModuleList):
self.padding_fill = padding_fill
def forward(self, species_aev):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
species, aev = species_aev
species_ = species.flatten()
present_species = utils.present_species(species)
......@@ -53,6 +55,7 @@ class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules."""
def forward(self, species_input):
    # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
    """Average the second output of every member of the ensemble.

    Arguments:
        species_input (tuple): two-element tuple handed unchanged to each
            member; its first element is returned as is.

    Returns:
        tuple: the first element of ``species_input`` and the mean of the
        members' second outputs.
    """
    # each member returns a pair; only its second element is averaged
    outputs = [x(species_input)[1] for x in self]
    species, _ = species_input
    return species, sum(outputs) / len(outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment