Unverified Commit f2170e24 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript compatibility for AEVComputer (#303)

* make aev,model compatible with jit

* add type annotation to nn

* flake8 fix

* refactor AEVComputer

* fix doc

* an example with padding

* use Optional type instead of padding

* fix

* fix

* make pbc and cell keyword arguments in test_aev

* fix

* make pbc and cell keyword arguments in ase

* fix

* fix

* fix dtype

* fix

* aev_computer dtype to double

* change test files to have aev_computer with keyword argument

* fix JIT types

* add TestAEVJIT

* fix LGTM alerts

* fix TestAEVJIT

* Update aev.py

workaround for dtype in `torch.arange`

* More arange bugs

* Even more arange

* fix LGTM alert
parent 3957d19c
...@@ -146,7 +146,7 @@ class TestAEV(unittest.TestCase): ...@@ -146,7 +146,7 @@ class TestAEV(unittest.TestCase):
species = self.transform(species) species = self.transform(species)
expected_radial = self.transform(expected_radial) expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular) expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates, cell, pbc)) _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
self.assertAEVEqual(expected_radial, expected_angular, aev, 5e-5) self.assertAEVEqual(expected_radial, expected_angular, aev, 5e-5)
def testTripeptideMD(self): def testTripeptideMD(self):
...@@ -245,6 +245,12 @@ class TestAEV(unittest.TestCase): ...@@ -245,6 +245,12 @@ class TestAEV(unittest.TestCase):
) )
class TestAEVJIT(TestAEV):
def setUp(self):
super().setUp()
self.aev_computer = torch.jit.script(self.aev_computer)
class TestPBCSeeEachOther(unittest.TestCase): class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self): def setUp(self):
self.ani1x = torchani.models.ANI1x() self.ani1x = torchani.models.ANI1x()
...@@ -262,11 +268,11 @@ class TestPBCSeeEachOther(unittest.TestCase): ...@@ -262,11 +268,11 @@ class TestPBCSeeEachOther(unittest.TestCase):
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long) species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.bool) pbc = torch.ones(3, dtype=torch.bool)
_, aev = self.aev_computer((species, coordinates, cell, pbc)) _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
for _ in range(100): for _ in range(100):
translation = torch.randn(3, dtype=torch.double) translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc)) _, aev2 = self.aev_computer((species, coordinates + translation), cell=cell, pbc=pbc)
self.assertTrue(torch.allclose(aev, aev2)) self.assertTrue(torch.allclose(aev, aev2))
def testPBCConnersSeeEachOther(self): def testPBCConnersSeeEachOther(self):
...@@ -363,7 +369,7 @@ class TestAEVOnBoundary(unittest.TestCase): ...@@ -363,7 +369,7 @@ class TestAEVOnBoundary(unittest.TestCase):
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3) self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
ani1x = torchani.models.ANI1x() ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer.to(torch.double) self.aev_computer = ani1x.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc)) _, self.aev = self.aev_computer((self.species, self.center_coordinates), cell=self.cell, pbc=self.pbc)
def assertInCell(self, coordinates): def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell coordinates_cell = coordinates @ self.inv_cell
...@@ -385,7 +391,7 @@ class TestAEVOnBoundary(unittest.TestCase): ...@@ -385,7 +391,7 @@ class TestAEVOnBoundary(unittest.TestCase):
self.assertNotInCell(coordinates) self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc) coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates) self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc)) _, aev = self.aev_computer((self.species, coordinates), cell=self.cell, pbc=self.pbc)
self.assertGreater(aev.abs().max().item(), 0) self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev)) self.assertTrue(torch.allclose(aev, self.aev))
...@@ -402,7 +408,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase): ...@@ -402,7 +408,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase):
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O']) species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0) self.species = species_to_tensor(benzene.get_chemical_symbols()).unsqueeze(0)
self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float() self.coordinates = torch.tensor(benzene.get_positions()).unsqueeze(0).float()
_, self.aev = self.aev_computer((self.species, self.coordinates, self.cell, self.pbc)) _, self.aev = self.aev_computer((self.species, self.coordinates), cell=self.cell, pbc=self.pbc)
self.natoms = self.aev.shape[1] self.natoms = self.aev.shape[1]
def testRepeat(self): def testRepeat(self):
...@@ -416,7 +422,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase): ...@@ -416,7 +422,7 @@ class TestAEVOnBenzenePBC(unittest.TestCase):
self.coordinates + 3 * c1, self.coordinates + 3 * c1,
], dim=1) ], dim=1)
cell2 = torch.stack([4 * c1, c2, c3]) cell2 = torch.stack([4 * c1, c2, c3])
_, aev2 = self.aev_computer((species2, coordinates2, cell2, self.pbc)) _, aev2 = self.aev_computer((species2, coordinates2), cell=cell2, pbc=self.pbc)
for i in range(3): for i in range(3):
aev3 = aev2[:, i * self.natoms: (i + 1) * self.natoms, :] aev3 = aev2[:, i * self.natoms: (i + 1) * self.natoms, :]
self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance)) self.assertTrue(torch.allclose(self.aev, aev3, atol=tolerance))
......
...@@ -15,10 +15,11 @@ class TestEnergies(unittest.TestCase): ...@@ -15,10 +15,11 @@ class TestEnergies(unittest.TestCase):
def setUp(self): def setUp(self):
self.tolerance = 5e-5 self.tolerance = 5e-5
ani1x = torchani.models.ANI1x() ani1x = torchani.models.ANI1x()
aev_computer = ani1x.aev_computer self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0] nnp = ani1x.neural_networks[0]
shift_energy = ani1x.energy_shifter shift_energy = ani1x.energy_shifter
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy) self.nn = torch.nn.Sequential(nnp, shift_energy)
self.model = torch.nn.Sequential(self.aev_computer, nnp, shift_energy)
def random_skip(self): def random_skip(self):
return False return False
...@@ -56,7 +57,8 @@ class TestEnergies(unittest.TestCase): ...@@ -56,7 +57,8 @@ class TestEnergies(unittest.TestCase):
coordinates = self.transform(coordinates) coordinates = self.transform(coordinates)
species = self.transform(species) species = self.transform(species)
energies = self.transform(energies) energies = self.transform(energies)
_, energies_ = self.model((species, coordinates, cell, pbc)) _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nn((species, aev))
max_diff = (energies - energies_).abs().max().item() max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, tolerance) self.assertLess(max_diff, tolerance)
......
...@@ -14,8 +14,8 @@ class TestForce(unittest.TestCase): ...@@ -14,8 +14,8 @@ class TestForce(unittest.TestCase):
self.tolerance = 1e-5 self.tolerance = 1e-5
ani1x = torchani.models.ANI1x() ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0] self.nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, nnp) self.model = torch.nn.Sequential(self.aev_computer, self.nnp)
def random_skip(self): def random_skip(self):
return False return False
...@@ -82,7 +82,8 @@ class TestForce(unittest.TestCase): ...@@ -82,7 +82,8 @@ class TestForce(unittest.TestCase):
coordinates = self.transform(coordinates) coordinates = self.transform(coordinates)
species = self.transform(species) species = self.transform(species)
forces = self.transform(forces) forces = self.transform(forces)
_, energies_ = self.model((species, coordinates, cell, pbc)) _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
_, energies_ = self.nnp((species, aev))
derivative = torch.autograd.grad(energies_.sum(), derivative = torch.autograd.grad(energies_.sum(),
coordinates)[0] coordinates)[0]
max_diff = (forces + derivative).abs().max().item() max_diff = (forces + derivative).abs().max().item()
......
...@@ -2,17 +2,15 @@ from __future__ import division ...@@ -2,17 +2,15 @@ from __future__ import division
import torch import torch
from . import _six # noqa:F401 from . import _six # noqa:F401
import math import math
from typing import Tuple from typing import Tuple, Optional
# @torch.jit.script
def cutoff_cosine(distances, cutoff): def cutoff_cosine(distances, cutoff):
# type: (torch.Tensor, float) -> torch.Tensor # type: (torch.Tensor, float) -> torch.Tensor
# assuming all elements in distances are smaller than cutoff # assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5 return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances): def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (float, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the radial subAEV terms of the center atom given neighbors """Compute the radial subAEV terms of the center atom given neighbors
...@@ -40,7 +38,6 @@ def radial_terms(Rcr, EtaR, ShfR, distances): ...@@ -40,7 +38,6 @@ def radial_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2) return ret.flatten(start_dim=-2)
# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor # type: (float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor) -> torch.Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs. """Compute the angular subAEV terms of the center atom given neighbor pairs.
...@@ -77,8 +74,8 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): ...@@ -77,8 +74,8 @@ def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=-4)
# @torch.jit.script
def compute_shifts(cell, pbc, cutoff): def compute_shifts(cell, pbc, cutoff):
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
"""Compute the shifts of unit cell along the given cell vectors to make it """Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under large enough to contain all pairs of neighbor atoms with PBC under
consideration consideration
...@@ -95,14 +92,13 @@ def compute_shifts(cell, pbc, cutoff): ...@@ -95,14 +92,13 @@ def compute_shifts(cell, pbc, cutoff):
:class:`torch.Tensor`: long tensor of shifts. the center cell and :class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included. symmetric cells are not included.
""" """
# type: (torch.Tensor, torch.Tensor, float) -> torch.Tensor
reciprocal_cell = cell.inverse().t() reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1) inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long) num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats)) num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device) r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device, dtype=torch.long)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device) r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device, dtype=torch.long)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device) r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device, dtype=torch.long)
o = torch.zeros(1, dtype=torch.long, device=cell.device) o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([ return torch.cat([
torch.cartesian_prod(r1, r2, r3), torch.cartesian_prod(r1, r2, r3),
...@@ -121,8 +117,8 @@ def compute_shifts(cell, pbc, cutoff): ...@@ -121,8 +117,8 @@ def compute_shifts(cell, pbc, cutoff):
]) ])
# @torch.jit.script
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
"""Compute pairs of atoms that are neighbors """Compute pairs of atoms that are neighbors
Arguments: Arguments:
...@@ -135,21 +131,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -135,21 +131,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
cutoff (float): the cutoff inside which atoms are considered pairs cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
""" """
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
coordinates = coordinates.detach() coordinates = coordinates.detach()
cell = cell.detach() cell = cell.detach()
num_atoms = padding_mask.shape[1] num_atoms = padding_mask.shape[1]
all_atoms = torch.arange(num_atoms, device=cell.device) all_atoms = torch.arange(num_atoms, device=cell.device, dtype=torch.long)
# Step 2: center cell # Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1) p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = shifts.new_zeros(p1_center.shape[0], 3) shifts_center = torch.zeros((p1_center.shape[0], 3), dtype=shifts.dtype, device=shifts.device)
# Step 3: cells with shifts # Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3) # shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0] num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device) all_shifts = torch.arange(num_shifts, device=cell.device, dtype=torch.long)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1) shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
shifts_outide = shifts.index_select(0, shift_index) shifts_outide = shifts.index_select(0, shift_index)
...@@ -172,19 +166,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff): ...@@ -172,19 +166,19 @@ def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
return molecule_index + atom_index1, molecule_index + atom_index2, shifts return molecule_index + atom_index1, molecule_index + atom_index2, shifts
# torch.jit.script
def triu_index(num_species): def triu_index(num_species):
species = torch.arange(num_species) # type: (int) -> torch.Tensor
species = torch.arange(num_species, dtype=torch.long)
species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1) species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
pair_index = torch.arange(species1.shape[0]) pair_index = torch.arange(species1.shape[0], dtype=torch.long)
ret = torch.zeros(num_species, num_species, dtype=torch.long) ret = torch.zeros(num_species, num_species, dtype=torch.long)
ret[species1, species2] = pair_index ret[species1, species2] = pair_index
ret[species2, species1] = pair_index ret[species2, species1] = pair_index
return ret return ret
# torch.jit.script
def convert_pair_index(index): def convert_pair_index(index):
# type: (torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
"""Let's say we have a pair: """Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ... index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ... elem1: 0 0 1 0 1 2 0 1 2 3 ...
...@@ -208,15 +202,15 @@ def convert_pair_index(index): ...@@ -208,15 +202,15 @@ def convert_pair_index(index):
return index - num_elems, n + 1 return index - num_elems, n + 1
# torch.jit.script
def cumsum_from_zero(input_): def cumsum_from_zero(input_):
# type: (torch.Tensor) -> torch.Tensor
cumsum = torch.cumsum(input_, dim=0) cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_tensor([0]), cumsum[:-1]]) cumsum = torch.cat([torch.tensor([0], dtype=input_.dtype, device=input_.device), cumsum[:-1]])
return cumsum return cumsum
# torch.jit.script
def triple_by_molecule(atom_index1, atom_index2): def triple_by_molecule(atom_index1, atom_index2):
# type: (torch.Tensor, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
"""Input: indices for pairs of atoms that are close to each other. """Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists. (2, 1) exists.
...@@ -233,16 +227,18 @@ def triple_by_molecule(atom_index1, atom_index2): ...@@ -233,16 +227,18 @@ def triple_by_molecule(atom_index1, atom_index2):
sorted_ai1, rev_indices = ai1.sort() sorted_ai1, rev_indices = ai1.sort()
# sort and compute unique key # sort and compute unique key
uniqued_central_atom_index, counts = torch.unique_consecutive(sorted_ai1, return_counts=True) unique_results = torch.unique_consecutive(sorted_ai1, return_inverse=True, return_counts=True)
uniqued_central_atom_index = unique_results[0]
counts = unique_results[-1]
# do local combinations within unique key, assuming sorted # do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2 pair_sizes = (counts * (counts - 1) / 2).long()
total_size = pair_sizes.sum() total_size = pair_sizes.sum()
pair_indices = torch.repeat_interleave(pair_sizes) pair_indices = torch.repeat_interleave(pair_sizes)
central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices) central_atom_index = uniqued_central_atom_index.index_select(0, pair_indices)
cumsum = cumsum_from_zero(pair_sizes) cumsum = cumsum_from_zero(pair_sizes)
cumsum = cumsum.index_select(0, pair_indices) cumsum = cumsum.index_select(0, pair_indices)
sorted_local_pair_index = torch.arange(total_size, device=cumsum.device) - cumsum sorted_local_pair_index = torch.arange(total_size, device=cumsum.device, dtype=torch.long) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index) sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts) cumsum = cumsum_from_zero(counts)
cumsum = cumsum.index_select(0, pair_indices) cumsum = cumsum.index_select(0, pair_indices)
...@@ -259,8 +255,8 @@ def triple_by_molecule(atom_index1, atom_index2): ...@@ -259,8 +255,8 @@ def triple_by_molecule(atom_index1, atom_index2):
return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2 return central_atom_index, local_index1 % n, local_index2 % n, sign1, sign2
# torch.jit.script
def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes): def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes):
# type: (torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[float, torch.Tensor, torch.Tensor, float, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], Tuple[int, int, int, int, int, int]) > torch.Tensor
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0] num_molecules = species.shape[0]
...@@ -279,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -279,7 +275,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
# compute radial aev # compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances) radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength) radial_aev = torch.zeros((num_molecules * num_atoms * num_species, radial_sublength), dtype=radial_terms_.dtype, device=radial_terms_.device)
index1 = atom_index1 * num_species + species2 index1 = atom_index1 * num_species + species2
index2 = atom_index2 * num_species + species1 index2 = atom_index2 * num_species + species1
radial_aev.index_add_(0, index1, radial_terms_) radial_aev.index_add_(0, index1, radial_terms_)
...@@ -302,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes ...@@ -302,7 +298,7 @@ def compute_aev(species, coordinates, cell, shifts, triu_index, constants, sizes
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1]) species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2]) species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2) angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength) angular_aev = torch.zeros((num_molecules * num_atoms * num_species_pairs, angular_sublength), dtype=angular_terms_.dtype, device=angular_terms_.device)
index = central_atom_index * num_species_pairs + triu_index[species1_, species2_] index = central_atom_index * num_species_pairs + triu_index[species1_, species2_]
angular_aev.index_add_(0, index, angular_terms_) angular_aev.index_add_(0, index, angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length) angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
...@@ -380,23 +376,24 @@ class AEVComputer(torch.nn.Module): ...@@ -380,23 +376,24 @@ class AEVComputer(torch.nn.Module):
def constants(self): def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
# @torch.jit.script_method def forward(self, input_, cell=None, pbc=None):
def forward(self, input_): # type: (Tuple[torch.Tensor, torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""Compute AEVs """Compute AEVs
Arguments: Arguments:
input_ (tuple): Can be one of the following two cases: input_ (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all, If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates. then input can be a tuple of two tensors: species, coordinates.
species must have shape ``(C, A)`` and coordinates must have species must have shape ``(C, A)``, coordinates must have shape
shape ``(C, A, 3)``, where ``C`` is the number of molecules ``(C, A, 3)`` where ``C`` is the number of molecules in a chunk,
in a chunk, and ``A`` is the number of atoms. and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc would be a tuple of two tensors (species, coordinates) and two keyword
where species and coordinates are the same as described above, cell arguments `cell=...` , and `pbc=...` where species and coordinates are
is a tensor of shape (3, 3) of the three vectors defining unit cell: the same as described above, cell is a tensor of shape (3, 3) of the
three vectors defining unit cell:
.. code-block:: python .. code-block:: python
...@@ -412,13 +409,14 @@ class AEVComputer(torch.nn.Module): ...@@ -412,13 +409,14 @@ class AEVComputer(torch.nn.Module):
unchanged, and AEVs is a tensor of shape unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())`` ``(C, A, self.aev_length())``
""" """
if len(input_) == 2: species, coordinates = input_
species, coordinates = input_
if cell is None and pbc is None:
cell = self.default_cell cell = self.default_cell
shifts = self.default_shifts shifts = self.default_shifts
else: else:
assert len(input_) == 4 assert (cell is not None and pbc is not None)
species, coordinates, cell, pbc = input_
cutoff = max(self.Rcr, self.Rca) cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff) shifts = compute_shifts(cell, pbc, cutoff)
return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes) return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
...@@ -38,7 +38,8 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -38,7 +38,8 @@ class Calculator(ase.calculators.calculator.Calculator):
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to # aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object # make sure we do not change the original object
self.aev_computer = copy.deepcopy(aev_computer) aev_computer = copy.deepcopy(aev_computer)
self.aev_computer = aev_computer.to(dtype)
self.model = copy.deepcopy(model) self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter) self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite self.overwrite = overwrite
...@@ -46,8 +47,7 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -46,8 +47,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.device = self.aev_computer.EtaR.device self.device = self.aev_computer.EtaR.device
self.dtype = dtype self.dtype = dtype
self.whole = torch.nn.Sequential( self.nn = torch.nn.Sequential(
self.aev_computer,
self.model, self.model,
self.energy_shifter self.energy_shifter
).to(dtype) ).to(dtype)
...@@ -93,9 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -93,9 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y = self.strain(cell, displacement_y, 1) strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2) strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z cell = cell + strain_x + strain_y + strain_z
_, energy = self.whole((species, coordinates, cell, pbc)) _, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
else: else:
_, energy = self.whole((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
_, energy = self.nn((species, aev))
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item() self.results['free_energy'] = energy.item()
......
...@@ -28,6 +28,7 @@ shouldn't be used anymore. ...@@ -28,6 +28,7 @@ shouldn't be used anymore.
""" """
import torch import torch
from typing import Tuple
from pkg_resources import resource_filename from pkg_resources import resource_filename
from . import neurochem from . import neurochem
from .aev import AEVComputer from .aev import AEVComputer
...@@ -87,6 +88,7 @@ class BuiltinNet(torch.nn.Module): ...@@ -87,6 +88,7 @@ class BuiltinNet(torch.nn.Module):
self.species, self.ensemble_prefix, self.ensemble_size) self.species, self.ensemble_prefix, self.ensemble_size)
def forward(self, species_coordinates): def forward(self, species_coordinates):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""Calculates predicted properties for minibatch of configurations """Calculates predicted properties for minibatch of configurations
Args: Args:
......
import torch import torch
from . import utils from . import utils
from typing import Tuple
class ANIModel(torch.nn.ModuleList): class ANIModel(torch.nn.ModuleList):
...@@ -34,6 +35,7 @@ class ANIModel(torch.nn.ModuleList): ...@@ -34,6 +35,7 @@ class ANIModel(torch.nn.ModuleList):
self.padding_fill = padding_fill self.padding_fill = padding_fill
def forward(self, species_aev): def forward(self, species_aev):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
present_species = utils.present_species(species) present_species = utils.present_species(species)
...@@ -53,6 +55,7 @@ class Ensemble(torch.nn.ModuleList): ...@@ -53,6 +55,7 @@ class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules.""" """Compute the average output of an ensemble of modules."""
def forward(self, species_input): def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
outputs = [x(species_input)[1] for x in self] outputs = [x(species_input)[1] for x in self]
species, _ = species_input species, _ = species_input
return species, sum(outputs) / len(outputs) return species, sum(outputs) / len(outputs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment