Unverified Commit 465faa00 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

Use a modified Sequential class with type annotation (#308)

parent 9715749f
...@@ -48,8 +48,8 @@ model = torchani.neurochem.load_model(consts.species, model_dir) ...@@ -48,8 +48,8 @@ model = torchani.neurochem.load_model(consts.species, model_dir)
# (Coordinates) -[AEVComputer]-> (AEV) -[Neural Network]-> # (Coordinates) -[AEVComputer]-> (AEV) -[Neural Network]->
# (Raw energies) -[EnergyShifter]-> (Final energies) # (Raw energies) -[EnergyShifter]-> (Final energies)
# From using either the ensemble or a single model: # From using either the ensemble or a single model:
nnp1 = torch.nn.Sequential(aev_computer, ensemble, energy_shifter) nnp1 = torchani.nn.Sequential(aev_computer, ensemble, energy_shifter)
nnp2 = torch.nn.Sequential(aev_computer, model, energy_shifter) nnp2 = torchani.nn.Sequential(aev_computer, model, energy_shifter)
print(nnp1) print(nnp1)
print(nnp2) print(nnp2)
......
...@@ -180,7 +180,7 @@ nn.apply(init_params) ...@@ -180,7 +180,7 @@ nn.apply(init_params)
############################################################################### ###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks. # Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device) model = torchani.nn.Sequential(aev_computer, nn).to(device)
############################################################################### ###############################################################################
# Now let's setup the optimizers. NeuroChem uses Adam with decoupled weight decay # Now let's setup the optimizers. NeuroChem uses Adam with decoupled weight decay
......
...@@ -143,7 +143,7 @@ nn.apply(init_params) ...@@ -143,7 +143,7 @@ nn.apply(init_params)
############################################################################### ###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks. # Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device) model = torchani.nn.Sequential(aev_computer, nn).to(device)
############################################################################### ###############################################################################
# Here we will use Adam with weight decay for the weights and Stochastic Gradient # Here we will use Adam with weight decay for the weights and Stochastic Gradient
......
...@@ -93,7 +93,7 @@ else: ...@@ -93,7 +93,7 @@ else:
############################################################################### ###############################################################################
# Let's now create a pipeline of AEV Computer --> Neural Networks. # Let's now create a pipeline of AEV Computer --> Neural Networks.
model = torch.nn.Sequential(aev_computer, nn).to(device) model = torchani.nn.Sequential(aev_computer, nn).to(device)
############################################################################### ###############################################################################
# Now setup tensorboard # Now setup tensorboard
......
...@@ -18,8 +18,8 @@ class TestEnergies(unittest.TestCase): ...@@ -18,8 +18,8 @@ class TestEnergies(unittest.TestCase):
self.aev_computer = ani1x.aev_computer self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0] self.nnp = ani1x.neural_networks[0]
self.energy_shifter = ani1x.energy_shifter self.energy_shifter = ani1x.energy_shifter
self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter) self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter) self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
def random_skip(self): def random_skip(self):
return False return False
...@@ -120,8 +120,8 @@ class TestEnergiesEnergyShifterJIT(TestEnergies): ...@@ -120,8 +120,8 @@ class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self): def setUp(self):
super().setUp() super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter) self.energy_shifter = torch.jit.script(self.energy_shifter)
self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter) self.nn = torchani.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter) self.model = torchani.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,7 +15,7 @@ class TestForce(unittest.TestCase): ...@@ -15,7 +15,7 @@ class TestForce(unittest.TestCase):
ani1x = torchani.models.ANI1x() ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer self.aev_computer = ani1x.aev_computer
self.nnp = ani1x.neural_networks[0] self.nnp = ani1x.neural_networks[0]
self.model = torch.nn.Sequential(self.aev_computer, self.nnp) self.model = torchani.nn.Sequential(self.aev_computer, self.nnp)
def random_skip(self): def random_skip(self):
return False return False
......
...@@ -30,7 +30,7 @@ class TestIgnite(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestIgnite(unittest.TestCase):
def forward(self, x): def forward(self, x):
return x[0], x[1].flatten() return x[0], x[1].flatten()
model = torch.nn.Sequential(aev_computer, nnp, Flatten()) model = torchani.nn.Sequential(aev_computer, nnp, Flatten())
container = torchani.ignite.Container({'energies': model}) container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters()) optimizer = torch.optim.Adam(container.parameters())
loss = torchani.ignite.TransformedLoss( loss = torchani.ignite.TransformedLoss(
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import torch import torch
from .nn import Sequential
import ase.neighborlist import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator import ase.calculators.calculator
...@@ -47,7 +48,7 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -47,7 +48,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.device = self.aev_computer.EtaR.device self.device = self.aev_computer.EtaR.device
self.dtype = dtype self.dtype = dtype
self.nn = torch.nn.Sequential( self.nn = Sequential(
self.model, self.model,
self.energy_shifter self.energy_shifter
).to(dtype) ).to(dtype)
......
...@@ -31,6 +31,7 @@ import torch ...@@ -31,6 +31,7 @@ import torch
from typing import Tuple from typing import Tuple
from pkg_resources import resource_filename from pkg_resources import resource_filename
from . import neurochem from . import neurochem
from .nn import Sequential
from .aev import AEVComputer from .aev import AEVComputer
...@@ -113,10 +114,10 @@ class BuiltinNet(torch.nn.Module): ...@@ -113,10 +114,10 @@ class BuiltinNet(torch.nn.Module):
index (:class:`int`): Index of the model index (:class:`int`): Index of the model
Returns: Returns:
ret: (:class:`torch.nn.Sequential`): Sequential model ready for ret: (:class:`Sequential`): Sequential model ready for
calculations calculations
""" """
ret = torch.nn.Sequential( ret = Sequential(
self.aev_computer, self.aev_computer,
self.neural_networks[index], self.neural_networks[index],
self.energy_shifter self.energy_shifter
......
...@@ -12,7 +12,7 @@ import timeit ...@@ -12,7 +12,7 @@ import timeit
from . import _six # noqa:F401 from . import _six # noqa:F401
import collections import collections
import sys import sys
from ..nn import ANIModel, Ensemble, Gaussian from ..nn import ANIModel, Ensemble, Gaussian, Sequential
from ..utils import EnergyShifter, ChemicalSymbolsToInts from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer from ..aev import AEVComputer
from ..optim import AdamW from ..optim import AdamW
...@@ -544,7 +544,7 @@ if sys.version_info[0] > 2: ...@@ -544,7 +544,7 @@ if sys.version_info[0] > 2:
if self.aev_caching: if self.aev_caching:
self.model = self.nn.to(self.device) self.model = self.nn.to(self.device)
else: else:
self.model = torch.nn.Sequential(self.aev_computer, self.nn).to(self.device) self.model = Sequential(self.aev_computer, self.nn).to(self.device)
# loss functions # loss functions
self.mse_se = torch.nn.MSELoss(reduction='none') self.mse_se = torch.nn.MSELoss(reduction='none')
......
...@@ -61,6 +61,25 @@ class Ensemble(torch.nn.ModuleList): ...@@ -61,6 +61,25 @@ class Ensemble(torch.nn.ModuleList):
return species, sum(outputs) / len(outputs) return species, sum(outputs) / len(outputs)
class Sequential(torch.nn.Module):
"""Modified Sequential module that accept Tuple type as input"""
def __init__(self, *args):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], torch.OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def forward(self, input):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
for module in self._modules.values():
input = module(input)
return input
class Gaussian(torch.nn.Module): class Gaussian(torch.nn.Module):
"""Gaussian activation""" """Gaussian activation"""
def forward(self, x): def forward(self, x):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment