Unverified Commit 465faa00 authored by Farhad Ramezanghorbani, committed by GitHub

Use a modified Sequential class with type annotation (#308)

parent 9715749f
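
For context: torchani modules pass a (species, tensor) tuple from stage to stage, and the commit title suggests the stock torch.nn.Sequential could not carry the Tuple type annotation needed for TorchScript, which is why the examples, tests, and library code below switch to torchani.nn.Sequential. A minimal sketch (not from the commit itself) of the resulting usage, assuming the ANI1x attributes that appear in the diff (aev_computer, neural_networks, energy_shifter) and an illustrative, approximate methane geometry:

```python
import torch
import torchani

# Assemble the pipeline from the parts of a built-in model,
# much as the updated examples and tests below do.
ani1x = torchani.models.ANI1x()
pipeline = torchani.nn.Sequential(
    ani1x.aev_computer,
    ani1x.neural_networks[0],
    ani1x.energy_shifter,
)

# Species are indices into the model's element ordering (H, C, N, O for ANI-1x),
# so methane [C, H, H, H, H] becomes [1, 0, 0, 0, 0]; coordinates are in Angstrom
# and only roughly tetrahedral here.
species = torch.tensor([[1, 0, 0, 0, 0]])
coordinates = torch.tensor([[[ 0.0000,  0.0000,  0.0000],
                             [ 0.6291,  0.6291,  0.6291],
                             [-0.6291, -0.6291,  0.6291],
                             [ 0.6291, -0.6291, -0.6291],
                             [-0.6291,  0.6291, -0.6291]]])

# Every stage takes and returns a (species, tensor) tuple.
_, energies = pipeline((species, coordinates))
```
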
@@ -48,8 +48,8 @@ model = torchani.neurochem.load_model(consts.species, model_dir)
# (Coordinates) -[AEVComputer]-> (AEV) -[Neural Network]->
# (Raw energies) -[EnergyShifter]-> (Final energies)
# From using either the ensemble or a single model:
-nnp1 = torch.nn.Sequential(aev_computer, ensemble, energy_shifter)
-nnp2 = torch.nn.Sequential(aev_computer, model, energy_shifter)
+nnp1 = torchani.nn.Sequential(aev_computer, ensemble, energy_shifter)
+nnp2 = torchani.nn.Sequential(aev_computer, model, energy_shifter)
print(nnp1)
print(nnp2)
......
@@ -180,7 +180,7 @@ nn.apply(init_params)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
-model = torch.nn.Sequential(aev_computer, nn).to(device)
+model = torchani.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Now let's setup the optimizers. NeuroChem uses Adam with decoupled weight decay
......
@@ -143,7 +143,7 @@ nn.apply(init_params)
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
-model = torch.nn.Sequential(aev_computer, nn).to(device)
+model = torchani.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Here we will use Adam with weight decay for the weights and Stochastic Gradient
......
@@ -93,7 +93,7 @@ else:
###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks.
-model = torch.nn.Sequential(aev_computer, nn).to(device)
+model = torchani.nn.Sequential(aev_computer, nn).to(device)
###############################################################################
# Now setup tensorboard
......
@@ -18,8 +18,8 @@ class TestEnergies(unittest.TestCase):
self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0]
self.energy_shifter = ani1x.energy_shifter
-self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
-self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
+self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
+self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
def random_skip(self):
return False
@@ -120,8 +120,8 @@ class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self):
super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter)
-self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
-self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
+self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
+self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
if __name__ == '__main__':
......
@@ -15,7 +15,7 @@ class TestForce(unittest.TestCase):
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0]
-self.model = torch.nn.Sequential(self.aev_computer, self.nnp)
+self.model = torchani.nn.Sequential(self.aev_computer, self.nnp)
def random_skip(self):
return False
......
@@ -30,7 +30,7 @@ class TestIgnite(unittest.TestCase):
def forward(self, x):
return x[0], x[1].flatten()
-model = torch.nn.Sequential(aev_computer, nnp, Flatten())
+model = torchani.nn.Sequential(aev_computer, nnp, Flatten())
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters())
loss = torchani.ignite.TransformedLoss(
......
@@ -7,6 +7,7 @@
from __future__ import absolute_import
import torch
+from .nn import Sequential
import ase.neighborlist
from . import utils
import ase.calculators.calculator
@@ -47,7 +48,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.device = self.aev_computer.EtaR.device
self.dtype = dtype
-self.nn = torch.nn.Sequential(
+self.nn = Sequential(
self.model,
self.energy_shifter
).to(dtype)
......
@@ -31,6 +31,7 @@ import torch
from typing import Tuple
from pkg_resources import resource_filename
from . import neurochem
+from .nn import Sequential
from .aev import AEVComputer
@@ -113,10 +114,10 @@ class BuiltinNet(torch.nn.Module):
index (:class:`int`): Index of the model
Returns:
-ret: (:class:`torch.nn.Sequential`): Sequential model ready for
+ret: (:class:`Sequential`): Sequential model ready for
calculations
"""
-ret = torch.nn.Sequential(
+ret = Sequential(
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
......
@@ -12,7 +12,7 @@ import timeit
from . import _six # noqa:F401
import collections
import sys
-from ..nn import ANIModel, Ensemble, Gaussian
+from ..nn import ANIModel, Ensemble, Gaussian, Sequential
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..optim import AdamW
@@ -544,7 +544,7 @@ if sys.version_info[0] > 2:
if self.aev_caching:
self.model = self.nn.to(self.device)
else:
-self.model = torch.nn.Sequential(self.aev_computer, self.nn).to(self.device)
+self.model = Sequential(self.aev_computer, self.nn).to(self.device)
# loss functions
self.mse_se = torch.nn.MSELoss(reduction='none')
......
@@ -61,6 +61,25 @@ class Ensemble(torch.nn.ModuleList):
return species, sum(outputs) / len(outputs)
+
+class Sequential(torch.nn.Module):
+    """Modified Sequential module that accepts a Tuple as input"""
+
+    def __init__(self, *args):
+        super(Sequential, self).__init__()
+        # stdlib OrderedDict (assumes `import collections` at the top of this module)
+        if len(args) == 1 and isinstance(args[0], collections.OrderedDict):
+            for key, module in args[0].items():
+                self.add_module(key, module)
+        else:
+            for idx, module in enumerate(args):
+                self.add_module(str(idx), module)
+
+    def forward(self, input):
+        # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
+        for module in self._modules.values():
+            input = module(input)
+        return input
+
class Gaussian(torch.nn.Module):
"""Gaussian activation"""
def forward(self, x):
......
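
To make the behaviour of the class added above concrete, here is a small sketch (not part of the commit; AddOne is a made-up stand-in for torchani's real layers). Each stage receives the whole (species, tensor) tuple and returns one, and the type comment on forward declares that tuple interface explicitly, which is the "type annotation" the commit title refers to:

```python
from typing import Tuple

import torch
import torchani

class AddOne(torch.nn.Module):
    """Hypothetical toy module following torchani's (species, tensor) tuple convention."""
    def forward(self, input):
        # type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
        species, x = input
        return species, x + 1

pipe = torchani.nn.Sequential(AddOne(), AddOne())
species = torch.zeros(1, 3, dtype=torch.long)
values = torch.zeros(1, 3)

# The tuple is threaded through both stages, so the result is all twos.
_, out = pipe((species, values))
print(out)  # tensor([[2., 2., 2.]])
```
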