Unverified Commit 379d2b33 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

[JIT] Add TorchScript Compatibility for EnergyShifter (#306)

* enable EnergyShifter scripting

* fix

* fix
parent f2170e24
......@@ -16,10 +16,10 @@ class TestEnergies(unittest.TestCase):
self.tolerance = 5e-5
ani1x = torchani.models.ANI1x()
self.aev_computer = ani1x.aev_computer
nnp = ani1x.neural_networks[0]
shift_energy = ani1x.energy_shifter
self.nn = torch.nn.Sequential(nnp, shift_energy)
self.model = torch.nn.Sequential(self.aev_computer, nnp, shift_energy)
self.nnp = ani1x.neural_networks[0]
self.energy_shifter = ani1x.energy_shifter
self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
def random_skip(self):
return False
......@@ -116,5 +116,13 @@ class TestEnergies(unittest.TestCase):
self.assertLess(max_diff / math.sqrt(natoms), self.tolerance)
class TestEnergiesEnergyShifterJIT(TestEnergies):
def setUp(self):
super().setUp()
self.energy_shifter = torch.jit.script(self.energy_shifter)
self.nn = torch.nn.Sequential(self.nnp, self.energy_shifter)
self.model = torch.nn.Sequential(self.aev_computer, self.nnp, self.energy_shifter)
if __name__ == '__main__':
unittest.main()
......@@ -3,6 +3,7 @@ import torch.utils.data
import math
import numpy as np
from collections import defaultdict
from typing import Tuple
def pad(species):
......@@ -191,7 +192,7 @@ class EnergyShifter(torch.nn.Module):
intercept = self.self_energies[-1]
self_energies = self.self_energies[species]
self_energies[species == -1] = 0
self_energies[species == torch.tensor(-1)] = torch.tensor(0)
return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties):
......@@ -210,6 +211,7 @@ class EnergyShifter(torch.nn.Module):
return atomic_properties, properties
def forward(self, species_energies):
# type: (Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
"""(species, molecular energies)->(species, molecular energies + sae)
"""
species, energies = species_energies
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment