"vscode:/vscode.git/clone" did not exist on "3033f08201e4ec46d87c37098b066b8b1924b052"
Unverified Commit 379d2b33 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript Compatibility for EnergyShifter (#306)

* enable EnergyShifter scripting

* fix

* fix
parent f2170e24
...@@ -16,10 +16,10 @@ class TestEnergies(unittest.TestCase): ...@@ -16,10 +16,10 @@ class TestEnergies(unittest.TestCase):
self.tolerance = 5e-5 self.tolerance = 5e-5
ani1x = torchani.models.ANI1x() ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0] self.nnp = ani1x.neural_networks[0]
shift_energy = ani1x.energy_shifter self.energy_shifter = ani1x.energy_shifter
self.nn = torch.nn.Sequential(nnp, shift_energy) self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, nnp, shift_energy) self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
def random_skip(self): def random_skip(self):
return False return False
...@@ -116,5 +116,13 @@ class TestEnergies(unittest.TestCase): ...@@ -116,5 +116,13 @@ class TestEnergies(unittest.TestCase):
self.assertLess(max_diff / math.sqrt(natoms), self.tolerance) self.assertLess(max_diff / math.sqrt(natoms), self.tolerance)
class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self):
super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter)
self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -3,6 +3,7 @@ import torch.utils.data ...@@ -3,6 +3,7 @@ import torch.utils.data
import math import math
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Tuple
def pad(species): def pad(species):
...@@ -191,7 +192,7 @@ class EnergyShifter(torch.nn.Module): ...@@ -191,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
intercept = self.self_energies[-1] intercept = self.self_energies[-1]
self_energies = self.self_energies[species] self_energies = self.self_energies[species]
self_energies[species == -1] = 0 self_energies[species == torch.tensor(-1)] = torch.tensor(0)
return self_energies.sum(dim=1) + intercept return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties): def subtract_from_dataset(self, atomic_properties, properties):
...@@ -210,6 +211,7 @@ class EnergyShifter(torch.nn.Module): ...@@ -210,6 +211,7 @@ class EnergyShifter(torch.nn.Module):
return atomic_properties, properties return atomic_properties, properties
def forward(self, species_energies): def forward(self, species_energies):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
species, energies = species_energies species, energies = species_energies
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment