Unverified Commit 73e447f0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

improve energy shifter (#52)

parent 59b31d84
......@@ -12,8 +12,8 @@ aev_computer = torchani.SortedAEV(const_file=const_file)
prepare = torchani.PrepareInput(aev_computer.species)
nn = torchani.models.NeuroChemNNP(aev_computer.species, from_=network_dir,
ensemble=8)
model = torch.nn.Sequential(prepare, aev_computer, nn)
shift_energy = torchani.EnergyShifter(sae_file)
shift_energy = torchani.EnergyShifter(aev_computer.species, sae_file)
model = torch.nn.Sequential(prepare, aev_computer, nn, shift_energy)
coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
......@@ -25,7 +25,6 @@ species = ['C', 'H', 'H', 'H', 'H']
_, energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species)
force = -derivative
print('Energy:', energy.item())
......
......@@ -41,4 +41,4 @@ def get_or_create_model(filename, benchmark=False,
model.load_state_dict(torch.load(filename))
else:
torch.save(model.state_dict(), filename)
return model.to(device)
return model.to(device), torchani.EnergyShifter(aev_computer.species)
......@@ -48,13 +48,13 @@ device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
shift_energy = torchani.EnergyShifter()
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
True, device=device)
training, validation, testing = torchani.data.load_or_create(
parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
device=device, transform=[shift_energy.dataset_subtract_sae])
device=device, transform=[shift_energy.subtract_from_dataset])
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
container = torchani.ignite.Container({'energies': nnp})
parser.optim_args = json.loads(parser.optim_args)
......
......@@ -24,12 +24,12 @@ parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
shift_energy = torchani.EnergyShifter()
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
True, device=device)
dataset = torchani.data.ANIDataset(
parser.dataset_path, parser.chunk_size, device=device,
transform=[shift_energy.dataset_subtract_sae])
transform=[shift_energy.subtract_from_dataset])
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
......
......@@ -16,13 +16,13 @@ class TestEnergies(unittest.TestCase):
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
shift_energy = torchani.EnergyShifter(aev_computer.species)
self.model = torch.nn.Sequential(prepare, aev_computer,
nnp, shift_energy)
def _test_molecule(self, coordinates, species, energies):
shift_energy = torchani.EnergyShifter()
_, energies_ = self.model((species, coordinates))
energies_ = shift_energy.add_sae(energies_.squeeze(), species)
max_diff = (energies - energies_).abs().max().item()
max_diff = (energies - energies_.squeeze()).abs().max().item()
self.assertLess(max_diff, self.tolerance)
def testGDB(self):
......
import torch
import torchani
import unittest
import random
class TestEnergyShifter(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.species = torchani.SortedAEV().species
self.prepare = torchani.PrepareInput(self.species)
self.shift_energy = torchani.EnergyShifter(self.species)
def testSAEMatch(self):
for _ in range(10):
k = random.choice(range(5, 30))
species = random.choices(self.species, k=k)
species_tensor = self.prepare.species_to_tensor(
species, torch.device('cpu'))
e1 = self.shift_energy.sae_from_list(species)
e2 = self.shift_energy.sae_from_tensor(species_tensor)
self.assertLess(abs(e1 - e2), self.tol)
if __name__ == '__main__':
unittest.main()
......@@ -17,14 +17,15 @@ if sys.version_info.major >= 3:
class TestIgnite(unittest.TestCase):
def testIgnite(self):
shift_energy = torchani.EnergyShifter()
ds = torchani.data.ANIDataset(
path, chunksize, transform=[shift_energy.dataset_subtract_sae])
ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
aev_computer = torchani.SortedAEV()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
ds = torchani.data.ANIDataset(
path, chunksize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
class Flatten(torch.nn.Module):
def forward(self, x):
......
import torch
from .env import buildin_sae_file
class EnergyShifter:
"""Class that deal with self atomic energies.
class EnergyShifter(torch.nn.Module):
Attributes
----------
self_energies : dict
The dictionary that stores self energies of species.
"""
def __init__(self, self_energy_file=buildin_sae_file):
def __init__(self, species, self_energy_file=buildin_sae_file):
super(EnergyShifter, self).__init__()
# load self energies
self.self_energies = {}
with open(self_energy_file) as f:
......@@ -22,55 +17,24 @@ class EnergyShifter:
self.self_energies[name] = value
except Exception:
pass # ignore unrecognizable line
self_energies_tensor = [self.self_energies[s] for s in species]
self.register_buffer('self_energies_tensor',
torch.tensor(self_energies_tensor,
dtype=torch.double))
def subtract_sae(self, energies, species):
"""Subtract self atomic energies from `energies`.
Parameters
----------
energies : pytorch tensor of `dtype`
The tensor of any shape that stores the raw energies.
species : list of str
The list specifying the species of each atom. The length of the
list must be the same as the number of atoms.
Returns
-------
pytorch tensor of `dtype`
The tensor of the same shape as `energies` that stores the energies
with self atomic energies subtracted.
"""
s = 0
for i in species:
s += self.self_energies[i]
return energies - s
def sae_from_list(self, species):
energies = [self.self_energies[i] for i in species]
return sum(energies)
def add_sae(self, energies, species):
"""Add self atomic energies to `energies`
def sae_from_tensor(self, species):
return self.self_energies_tensor[species].sum().item()
Parameters
----------
energies : pytorch tensor of `dtype`
The tensor of any shape that stores the energies excluding self
atomic energies.
species : list of str
The list specifying the species of each atom. The length of the
list must be the same as the number of atoms.
Returns
-------
pytorch tensor of `dtype`
The tensor of the same shape as `energies` that stores the raw
energies, i.e. the energy including self atomic energies.
"""
s = 0
for i in species:
s += self.self_energies[i]
return energies + s
def dataset_subtract_sae(self, data):
"""Allow object of this class to be used as transforms of pytorch's
dataset.
"""
data['energies'] = self.subtract_sae(data['energies'], data['species'])
def subtract_from_dataset(self, data):
sae = self.sae_from_list(data['species'])
data['energies'] -= sae
return data
def forward(self, species_energies):
species, energies = species_energies
sae = self.sae_from_tensor(species)
return species, energies + sae
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment