Unverified Commit a1adceb0 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript Compatibility for ANIModel and Ensemble (#307)

* add test for scripted ensemble

* use torchani.nn.Sequential

* change OrderedDict to module list

* fix

* fix nn.py

* try more fix

* try

* more

* more

* fix more

* rename

* bring ensemble back

* make ANIModel iterable
parent b59551d8
......@@ -65,6 +65,7 @@ class TestEnergies(unittest.TestCase):
class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self):
super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter)
......@@ -72,5 +73,21 @@ class TestEnergiesEnergyShifterJIT(TestEnergies):
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
class TestEnergiesANIModelJIT(TestEnergies):
def setUp(self):
super().setUp()
self.nnp = torch.jit.script(self.nnp)
self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
class TestEnergiesJIT(TestEnergies):
def setUp(self):
super().setUp()
self.model = torch.jit.script(self.model)
if __name__ == '__main__':
unittest.main()
......@@ -13,16 +13,15 @@ class TestEnsemble(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.conformations = 20
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.model_iterator = ani1x.neural_networks
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)
def _test_molecule(self, coordinates, species):
ani1x = torchani.models.ANI1x()
model_list = [torchani.nn.Sequential(self.aev_computer, m) for m in self.model_iterator]
coordinates.requires_grad_(True)
aev = ani1x.aev_computer
model_iterator = ani1x.neural_networks
model_list = [torch.nn.Sequential(aev, m) for m in model_iterator]
ensemble = torch.nn.Sequential(aev, model_iterator)
_, energy1 = ensemble((species, coordinates))
_, energy1 = self.ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates))[1] for m in model_list]
energy2 = sum(energy2) / len(model_list)
......@@ -42,5 +41,13 @@ class TestEnsemble(unittest.TestCase):
self._test_molecule(coordinates, species)
class TestEnsembleJIT(TestEnsemble):
def setUp(self):
super().setUp()
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)
self.ensemble = torch.jit.script(self.ensemble)
if __name__ == '__main__':
unittest.main()
import torch
from . import utils
from typing import Tuple
class ANIModel(torch.nn.ModuleList):
class ANIModel(torch.nn.Module):
"""ANI model that compute properties from species and AEVs.
Different atom types might have different modules, when computing
......@@ -17,9 +16,6 @@ class ANIModel(torch.nn.ModuleList):
:attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`.
reducer (:class:`collections.abc.Callable`): The callable that reduce
atomic outputs into molecular outputs. It must have signature
``(tensor, dim)->tensor``.
padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For
......@@ -29,55 +25,88 @@ class ANIModel(torch.nn.ModuleList):
:obj:`math.inf`.
"""
def __init__(self, modules, reducer=torch.sum, padding_fill=0):
super(ANIModel, self).__init__(modules)
self.reducer = reducer
def __init__(self, modules, padding_fill=0):
super(ANIModel, self).__init__()
self.module_list = torch.nn.ModuleList(modules)
self.padding_fill = padding_fill
def __getitem__(self, i):
return self.module_list[i]
def forward(self, species_aev):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
species, aev = species_aev
species_ = species.flatten()
present_species = utils.present_species(species)
aev = aev.flatten(0, 1)
output = torch.full_like(species_, self.padding_fill,
output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype)
for i in present_species:
i = 0
for m in self.module_list:
mask = (species_ == i)
input_ = aev.index_select(0, mask.nonzero().squeeze())
output.masked_scatter_(mask, self[i](input_).squeeze())
i += 1
midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species)
return species, self.reducer(output, dim=1)
return species, torch.sum(output, dim=1)
class Ensemble(torch.nn.ModuleList):
class Ensemble(torch.nn.Module):
"""Compute the average output of an ensemble of modules."""
# FIXME: due to PyTorch bug, we have to hard code the
# ensemble size to 8.
# def __init__(self, modules):
# super(Ensemble, self).__init__()
# self.modules_list = torch.nn.ModuleList(modules)
# def forward(self, species_input):
# # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
# outputs = [x(species_input)[1] for x in self.modules_list]
# species, _ = species_input
# return species, sum(outputs) / len(outputs)
def __init__(self, modules):
super(Ensemble, self).__init__()
assert len(modules) == 8
self.model0 = modules[0]
self.model1 = modules[1]
self.model2 = modules[2]
self.model3 = modules[3]
self.model4 = modules[4]
self.model5 = modules[5]
self.model6 = modules[6]
self.model7 = modules[7]
def __getitem__(self, i):
return [self.model0, self.model1, self.model2, self.model3,
self.model4, self.model5, self.model6, self.model7][i]
def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
outputs = [x(species_input)[1] for x in self]
species, _ = species_input
return species, sum(outputs) / len(outputs)
sum_ = self.model0(species_input)[1] + self.model1(species_input)[1] \
+ self.model2(species_input)[1] + self.model3(species_input)[1] \
+ self.model4(species_input)[1] + self.model5(species_input)[1] \
+ self.model6(species_input)[1] + self.model7(species_input)[1]
return species, sum_ / 8.0
class Sequential(torch.nn.Module):
"""Modified Sequential module that accept Tuple type as input"""
def __init__(self, *args):
def __init__(self, *modules):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], torch.OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
for module in self._modules.values():
input = module(input)
return input
for module in self.modules_list:
input_ = module(input_)
return input_
class Gaussian(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment