Unverified Commit a1adceb0 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript Compatibility for ANIModel and Ensemble (#307)

* add test for scripted ensemble

* use torchani.nn.Sequential

* change OrderedDict to module list

* fix

* fix nn.py

* try more fix

* try

* more

* more

* fix more

* rename

* bring ensemble back

* make ANIModel iterable
parent b59551d8
...@@ -65,6 +65,7 @@ class TestEnergies(unittest.TestCase): ...@@ -65,6 +65,7 @@ class TestEnergies(unittest.TestCase):
class TestEnergiesEnergyShifterJIT(TestEnergies): class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter) self.energy_shifter = torch.jit.script(self.energy_shifter)
...@@ -72,5 +73,21 @@ class TestEnergiesEnergyShifterJIT(TestEnergies): ...@@ -72,5 +73,21 @@ class TestEnergiesEnergyShifterJIT(TestEnergies):
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter) self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
class TestEnergiesANIModelJIT(TestEnergies):
def setUp(self):
super().setUp()
self.nnp = torch.jit.script(self.nnp)
self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
class TestEnergiesJIT(TestEnergies):
def setUp(self):
super().setUp()
self.model = torch.jit.script(self.model)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -13,16 +13,15 @@ class TestEnsemble(unittest.TestCase): ...@@ -13,16 +13,15 @@ class TestEnsemble(unittest.TestCase):
def setUp(self): def setUp(self):
self.tol = 1e-5 self.tol = 1e-5
self.conformations = 20 self.conformations = 20
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.model_iterator = ani1x.neural_networks
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)
def _test_molecule(self, coordinates, species): def _test_molecule(self, coordinates, species):
ani1x = torchani.models.ANI1x() model_list = [torchani.nn.Sequential(self.aev_computer, m) for m in self.model_iterator]
coordinates.requires_grad_(True) coordinates.requires_grad_(True)
aev = ani1x.aev_computer _, energy1 = self.ensemble((species, coordinates))
model_iterator = ani1x.neural_networks
model_list = [torch.nn.Sequential(aev, m) for m in model_iterator]
ensemble = torch.nn.Sequential(aev, model_iterator)
_, energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0] force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates))[1] for m in model_list] energy2 = [m((species, coordinates))[1] for m in model_list]
energy2 = sum(energy2) / len(model_list) energy2 = sum(energy2) / len(model_list)
...@@ -42,5 +41,13 @@ class TestEnsemble(unittest.TestCase): ...@@ -42,5 +41,13 @@ class TestEnsemble(unittest.TestCase):
self._test_molecule(coordinates, species) self._test_molecule(coordinates, species)
class TestEnsembleJIT(TestEnsemble):
def setUp(self):
super().setUp()
self.ensemble = torchani.nn.Sequential(self.aev_computer, self.model_iterator)
self.ensemble = torch.jit.script(self.ensemble)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
import torch import torch
from . import utils
from typing import Tuple from typing import Tuple
class ANIModel(torch.nn.ModuleList): class ANIModel(torch.nn.Module):
"""ANI model that compute properties from species and AEVs. """ANI model that compute properties from species and AEVs.
Different atom types might have different modules, when computing Different atom types might have different modules, when computing
...@@ -17,9 +16,6 @@ class ANIModel(torch.nn.ModuleList): ...@@ -17,9 +16,6 @@ class ANIModel(torch.nn.ModuleList):
:attr:`modules`, which means, for example ``modules[i]`` must be :attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`. module by putting the same reference in :attr:`modules`.
reducer (:class:`collections.abc.Callable`): The callable that reduce
atomic outputs into molecular outputs. It must have signature
``(tensor, dim)->tensor``.
padding_fill (float): The value to fill output of padding atoms. padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For be appropriately chosen so that it has no effect on the result. For
...@@ -29,55 +25,88 @@ class ANIModel(torch.nn.ModuleList): ...@@ -29,55 +25,88 @@ class ANIModel(torch.nn.ModuleList):
:obj:`math.inf`. :obj:`math.inf`.
""" """
def __init__(self, modules, reducer=torch.sum, padding_fill=0): def __init__(self, modules, padding_fill=0):
super(ANIModel, self).__init__(modules) super(ANIModel, self).__init__()
self.reducer = reducer self.module_list = torch.nn.ModuleList(modules)
self.padding_fill = padding_fill self.padding_fill = padding_fill
def __getitem__(self, i):
return self.module_list[i]
def forward(self, species_aev): def forward(self, species_aev):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
present_species = utils.present_species(species)
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
output = torch.full_like(species_, self.padding_fill, output = torch.full(species_.shape, self.padding_fill,
dtype=aev.dtype) dtype=aev.dtype)
for i in present_species: i = 0
for m in self.module_list:
mask = (species_ == i) mask = (species_ == i)
input_ = aev.index_select(0, mask.nonzero().squeeze()) i += 1
output.masked_scatter_(mask, self[i](input_).squeeze()) midx = mask.nonzero().flatten()
if midx.shape[0] > 0:
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species) output = output.view_as(species)
return species, self.reducer(output, dim=1) return species, torch.sum(output, dim=1)
class Ensemble(torch.nn.ModuleList): class Ensemble(torch.nn.Module):
"""Compute the average output of an ensemble of modules.""" """Compute the average output of an ensemble of modules."""
# FIXME: due to PyTorch bug, we have to hard code the
# ensemble size to 8.
# def __init__(self, modules):
# super(Ensemble, self).__init__()
# self.modules_list = torch.nn.ModuleList(modules)
# def forward(self, species_input):
# # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
# outputs = [x(species_input)[1] for x in self.modules_list]
# species, _ = species_input
# return species, sum(outputs) / len(outputs)
def __init__(self, modules):
super(Ensemble, self).__init__()
assert len(modules) == 8
self.model0 = modules[0]
self.model1 = modules[1]
self.model2 = modules[2]
self.model3 = modules[3]
self.model4 = modules[4]
self.model5 = modules[5]
self.model6 = modules[6]
self.model7 = modules[7]
def __getitem__(self, i):
return [self.model0, self.model1, self.model2, self.model3,
self.model4, self.model5, self.model6, self.model7][i]
def forward(self, species_input): def forward(self, species_input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
outputs = [x(species_input)[1] for x in self]
species, _ = species_input species, _ = species_input
return species, sum(outputs) / len(outputs) sum_ = self.model0(species_input)[1] + self.model1(species_input)[1] \
+ self.model2(species_input)[1] + self.model3(species_input)[1] \
+ self.model4(species_input)[1] + self.model5(species_input)[1] \
+ self.model6(species_input)[1] + self.model7(species_input)[1]
return species, sum_ / 8.0
class Sequential(torch.nn.Module): class Sequential(torch.nn.Module):
"""Modified Sequential module that accept Tuple type as input""" """Modified Sequential module that accept Tuple type as input"""
def __init__(self, *args): def __init__(self, *modules):
super(Sequential, self).__init__() super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], torch.OrderedDict): self.modules_list = torch.nn.ModuleList(modules)
for key, module in args[0].items():
self.add_module(key, module) def forward(self, input_):
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor] # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
for module in self._modules.values(): for module in self.modules_list:
input = module(input) input_ = module(input_)
return input return input_
class Gaussian(torch.nn.Module): class Gaussian(torch.nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment