Unverified commit ce21d224, authored by Gao, Xiang, committed by GitHub

Some API improvements (#73)

parent f146feca
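
The changes below rename torchani.padding to torchani.utils, move EnergyShifter into it, split load_model into separate load_model/load_model_ensemble functions, and expose pre-built objects as torchani.buildins. A minimal sketch of the reworked API, pieced together from the updated tests (the random coordinates/species below are placeholder inputs, not part of this commit):

import torch
import torchani

# Objects constructed once from the bundled ANI-1x files and exposed on
# torchani.buildins (see the new Buildins class in the neurochem diff below).
aev_computer = torchani.buildins.aev_computer
nnp = torchani.buildins.models[0]            # one network of the 8-model ensemble
shift_energy = torchani.buildins.energy_shifter

# Networks compose with plain torch.nn.Sequential, as in the updated tests.
model = torch.nn.Sequential(aev_computer, nnp, shift_energy)

coordinates = torch.randn(1, 5, 3, requires_grad=True)  # placeholder input
species = torch.zeros(1, 5, dtype=torch.long)           # placeholder species indices
_, energies = model((species, coordinates))
# gradient w.r.t. coordinates, computed the same way as in the force tests
force = torch.autograd.grad(energies.sum(), coordinates)[0]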
@@ -12,8 +12,8 @@ ensemble = 8
 consts = torchani.neurochem.Constants(const_file)
 sae = torchani.neurochem.load_sae(sae_file)
 aev_computer = torchani.AEVComputer(**consts)
-nn = torchani.neurochem.load_model(consts.species, from_=network_dir,
-                                   ensemble=ensemble)
+nn = torchani.neurochem.load_model_ensemble(consts.species, network_dir,
+                                            ensemble)
 shift_energy = torchani.EnergyShifter(consts.species, sae)
 model = torch.nn.Sequential(aev_computer, nn, shift_energy)
...
@@ -17,9 +17,7 @@ def atomic():
 
 
 def get_or_create_model(filename, device=torch.device('cpu')):
-    consts = torchani.neurochem.Constants()
-    sae = torchani.neurochem.load_sae()
-    aev_computer = torchani.AEVComputer(**consts)
+    aev_computer = torchani.buildins.aev_computer
     model = torchani.ANIModel([
         ('C', atomic()),
         ('H', atomic()),
@@ -37,4 +35,4 @@ def get_or_create_model(filename, device=torch.device('cpu')):
         model.load_state_dict(torch.load(filename))
     else:
         torch.save(model.state_dict(), filename)
-    return model.to(device), torchani.EnergyShifter(consts.species, sae)
+    return model.to(device)
@@ -19,16 +19,13 @@ parser.add_argument('--batch_size',
                     default=1024, type=int)
 parser.add_argument('--const_file',
                     help='File storing constants',
-                    default=torchani.neurochem.buildin_const_file)
+                    default=torchani.buildins.const_file)
 parser.add_argument('--sae_file',
                     help='File storing self atomic energies',
-                    default=torchani.neurochem.buildin_sae_file)
+                    default=torchani.buildins.sae_file)
 parser.add_argument('--network_dir',
                     help='Directory or prefix of directories storing networks',
-                    default=None)
-parser.add_argument('--ensemble',
-                    help='Number of models in ensemble',
-                    default=False)
+                    default=torchani.buildins.ensemble_prefix + '0/networks')
 parser = parser.parse_args()
 
 # load modules and datasets
@@ -36,9 +33,7 @@ device = torch.device(parser.device)
 consts = torchani.neurochem.Constants(parser.const_file)
 sae = torchani.neurochem.load_sae(parser.sae_file)
 aev_computer = torchani.AEVComputer(**consts)
-nn = torchani.neurochem.load_model(consts.species,
-                                   from_=parser.network_dir,
-                                   ensemble=parser.ensemble)
+nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
 model = torch.nn.Sequential(aev_computer, nn)
 container = torchani.training.Container({'energies': model})
 container = container.to(device)
...
@@ -51,8 +51,8 @@ device = torch.device(parser.device)
 writer = tensorboardX.SummaryWriter(log_dir=parser.log)
 start = timeit.default_timer()
 
-nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
-                                              device=device)
+nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
+shift_energy = torchani.buildins.energy_shifter
 training, validation, testing = torchani.training.load_or_create(
     parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
     parser.dataset_path, device=device,
...
@@ -21,7 +21,8 @@ parser = parser.parse_args()
 
 # set up benchmark
 device = torch.device(parser.device)
-nnp, shift_energy = model.get_or_create_model('/tmp/model.pt', device=device)
+nnp = model.get_or_create_model('/tmp/model.pt', device=device)
+shift_energy = torchani.buildins.energy_shifter
 dataset = torchani.training.BatchedANIDataset(
     parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
...
@@ -11,8 +11,7 @@ N = 97
 
 class TestAEV(unittest.TestCase):
 
     def setUp(self):
-        self.constants = torchani.neurochem.Constants()
-        self.aev_computer = torchani.AEVComputer(**self.constants)
+        self.aev_computer = torchani.buildins.aev_computer
         self.radial_length = self.aev_computer.radial_length()
         self.tolerance = 1e-5
@@ -44,7 +43,7 @@ class TestAEV(unittest.TestCase):
                 coordinates, species, radial, angular, _, _ = pickle.load(f)
                 species_coordinates.append((species, coordinates))
                 radial_angular.append((radial, angular))
-        species, coordinates = torchani.padding.pad_and_batch(
+        species, coordinates = torchani.utils.pad_and_batch(
             species_coordinates)
         _, aev = self.aev_computer((species, coordinates))
         start = 0
...
@@ -6,7 +6,7 @@ import unittest
 path = os.path.dirname(os.path.realpath(__file__))
 dataset_path = os.path.join(path, '../dataset')
 batch_size = 256
-consts = torchani.neurochem.Constants()
+consts = torchani.buildins.consts
 
 
 class TestData(unittest.TestCase):
@@ -26,7 +26,7 @@ class TestData(unittest.TestCase):
         coordinates2 = torch.randn(2, 8, 3)
         species3 = torch.randint(4, (10, 20), dtype=torch.long)
         coordinates3 = torch.randn(10, 20, 3)
-        species, coordinates = torchani.padding.pad_and_batch([
+        species, coordinates = torchani.utils.pad_and_batch([
             (species1, coordinates1),
             (species2, coordinates2),
             (species3, coordinates3),
@@ -44,19 +44,19 @@ class TestData(unittest.TestCase):
             self.assertGreater(conformations, 0)
             s_ = species[start:start+conformations, ...]
             c_ = coordinates[start:start+conformations, ...]
-            s_, c_ = torchani.padding.strip_redundant_padding(s_, c_)
+            s_, c_ = torchani.utils.strip_redundant_padding(s_, c_)
             self._assertTensorEqual(s, s_)
             self._assertTensorEqual(c, c_)
             start += conformations
-        s, c = torchani.padding.pad_and_batch(chunks)
+        s, c = torchani.utils.pad_and_batch(chunks)
         self._assertTensorEqual(s, species)
         self._assertTensorEqual(c, coordinates)
 
     def testTensorShape(self):
         for i in self.ds:
             input, output = i
-            species, coordinates = torchani.padding.pad_and_batch(input)
+            species, coordinates = torchani.utils.pad_and_batch(input)
             energies = output['energies']
             self.assertEqual(len(species.shape), 2)
             self.assertLessEqual(species.shape[0], batch_size)
...
@@ -13,11 +13,9 @@ class TestEnergies(unittest.TestCase):
 
     def setUp(self):
         self.tolerance = 5e-5
-        consts = torchani.neurochem.Constants()
-        sae = torchani.neurochem.load_sae()
-        aev_computer = torchani.AEVComputer(**consts)
-        nnp = torchani.neurochem.load_model(consts.species)
-        shift_energy = torchani.EnergyShifter(consts.species, sae)
+        aev_computer = torchani.buildins.aev_computer
+        nnp = torchani.buildins.models[0]
+        shift_energy = torchani.buildins.energy_shifter
         self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
 
     def testIsomers(self):
@@ -38,7 +36,7 @@ class TestEnergies(unittest.TestCase):
                 coordinates, species, _, _, e, _ = pickle.load(f)
                 species_coordinates.append((species, coordinates))
                 energies.append(e)
-        species, coordinates = torchani.padding.pad_and_batch(
+        species, coordinates = torchani.utils.pad_and_batch(
             species_coordinates)
         energies = torch.cat(energies)
         _, energies_ = self.model((species, coordinates))
...
@@ -16,22 +16,15 @@ class TestEnsemble(unittest.TestCase):
 
     def _test_molecule(self, coordinates, species):
         coordinates = torch.tensor(coordinates, requires_grad=True)
-        n = torchani.neurochem.buildin_ensemble
-        prefix = torchani.neurochem.buildin_model_prefix
-        consts = torchani.neurochem.Constants()
-        aev = torchani.AEVComputer(**consts)
-        ensemble = torchani.neurochem.load_model(consts.species, ensemble=True)
+        aev = torchani.buildins.aev_computer
+        ensemble = torchani.buildins.models
+        models = [torch.nn.Sequential(aev, m) for m in ensemble]
         ensemble = torch.nn.Sequential(aev, ensemble)
-        models = [torchani.neurochem.load_model(
-            consts.species, ensemble=False,
-            from_=prefix + '{}/networks/'.format(i))
-            for i in range(n)]
-        models = [torch.nn.Sequential(aev, m) for m in models]
         _, energy1 = ensemble((species, coordinates))
         force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
         energy2 = [m((species, coordinates))[1] for m in models]
-        energy2 = sum(energy2) / n
+        energy2 = sum(energy2) / len(models)
         force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
         energy_diff = (energy1 - energy2).abs().max().item()
         force_diff = (force1 - force2).abs().max().item()
...
@@ -12,9 +12,8 @@ class TestForce(unittest.TestCase):
 
     def setUp(self):
         self.tolerance = 1e-5
-        consts = torchani.neurochem.Constants()
-        aev_computer = torchani.AEVComputer(**consts)
-        nnp = torchani.neurochem.load_model(consts.species)
+        aev_computer = torchani.buildins.aev_computer
+        nnp = torchani.buildins.models[0]
         self.model = torch.nn.Sequential(aev_computer, nnp)
 
     def testIsomers(self):
@@ -39,7 +38,7 @@ class TestForce(unittest.TestCase):
                 coordinates = torch.tensor(coordinates, requires_grad=True)
                 species_coordinates.append((species, coordinates))
                 coordinates_forces.append((coordinates, forces))
-        species, coordinates = torchani.padding.pad_and_batch(
+        species, coordinates = torchani.utils.pad_and_batch(
             species_coordinates)
         _, energies = self.model((species, coordinates))
         energies = energies.sum()
...
 import os
 import unittest
 import torch
+import copy
 from ignite.engine import create_supervised_trainer, \
     create_supervised_evaluator, Events
 import torchani
@@ -15,13 +16,11 @@ threshold = 1e-5
 
 class TestIgnite(unittest.TestCase):
 
     def testIgnite(self):
-        consts = torchani.neurochem.Constants()
-        sae = torchani.neurochem.load_sae()
-        aev_computer = torchani.AEVComputer(**consts)
-        nnp = torchani.neurochem.load_model(consts.species)
-        shift_energy = torchani.EnergyShifter(consts.species, sae)
+        aev_computer = torchani.buildins.aev_computer
+        nnp = copy.deepcopy(torchani.buildins.models[0])
+        shift_energy = torchani.buildins.energy_shifter
         ds = torchani.training.BatchedANIDataset(
-            path, consts.species, batchsize,
+            path, torchani.buildins.consts.species, batchsize,
             transform=[shift_energy.subtract_from_dataset])
         ds = torch.utils.data.Subset(ds, [0])
...
@@ -10,7 +10,7 @@ class TestPadAndBatch(unittest.TestCase):
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.padding.pad_and_batch([
+        species, coordinates = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
        ])
@@ -33,7 +33,7 @@
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.padding.pad_and_batch([
+        species, coordinates = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
        ])
@@ -62,7 +62,7 @@
         coordinates1 = torch.zeros(5, 4, 3)
         species2 = torch.LongTensor([3, 2, 0, 1, 0])
         coordinates2 = torch.zeros(2, 5, 3)
-        species, coordinates = torchani.padding.pad_and_batch([
+        species, coordinates = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
        ])
@@ -82,7 +82,7 @@
 
     def testPresentSpecies(self):
         species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
-        present_species = torchani.padding.present_species(species)
+        present_species = torchani.utils.present_species(species)
         expected = torch.LongTensor([0, 1, 3, 7])
         self.assertEqual((expected - present_species).abs().max().item(), 0)
@@ -97,22 +97,22 @@ class TestStripRedundantPadding(unittest.TestCase):
         coordinates1 = torch.randn(5, 4, 3)
         species2 = torch.randint(4, (2, 5), dtype=torch.long)
         coordinates2 = torch.randn(2, 5, 3)
-        species12, coordinates12 = torchani.padding.pad_and_batch([
+        species12, coordinates12 = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
        ])
         species3 = torch.randint(4, (2, 10), dtype=torch.long)
         coordinates3 = torch.randn(2, 10, 3)
-        species123, coordinates123 = torchani.padding.pad_and_batch([
+        species123, coordinates123 = torchani.utils.pad_and_batch([
            (species1, coordinates1),
            (species2, coordinates2),
            (species3, coordinates3),
        ])
-        species1_, coordinates1_ = torchani.padding.strip_redundant_padding(
+        species1_, coordinates1_ = torchani.utils.strip_redundant_padding(
            species123[:5, ...], coordinates123[:5, ...])
         self._assertTensorEqual(species1_, species1)
         self._assertTensorEqual(coordinates1_, coordinates1)
-        species12_, coordinates12_ = torchani.padding.strip_redundant_padding(
+        species12_, coordinates12_ = torchani.utils.strip_redundant_padding(
            species123[:7, ...], coordinates123[:7, ...])
         self._assertTensorEqual(species12_, species12)
         self._assertTensorEqual(coordinates12_, coordinates12)
...
-from .energyshifter import EnergyShifter
+from .utils import EnergyShifter
 from .models import ANIModel, Ensemble
 from .aev import AEVComputer
 from . import training
-from . import padding
+from . import utils
 from . import neurochem
+from .neurochem import buildins
 
-__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
-           'training', 'padding', 'neurochem']
+__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
+           'training', 'utils', 'neurochem']
 import torch
 import itertools
 import math
-from . import padding
+from . import utils
 
 
 def _cutoff_cosine(distances, cutoff):
@@ -363,7 +363,7 @@ class AEVComputer(torch.nn.Module):
 
     def forward(self, species_coordinates):
         species, coordinates = species_coordinates
-        present_species = padding.present_species(species)
+        present_species = utils.present_species(species)
 
         # TODO: remove this workaround after gather support broadcasting
         atoms = coordinates.shape[1]
...
-import torch
-
-
-class EnergyShifter(torch.nn.Module):
-
-    def __init__(self, species, self_energies):
-        super(EnergyShifter, self).__init__()
-        self_energies_tensor = [self_energies[s] for s in species]
-        self.register_buffer('self_energies_tensor',
-                             torch.tensor(self_energies_tensor,
-                                          dtype=torch.double))
-
-    def sae(self, species):
-        self_energies = self.self_energies_tensor[species]
-        self_energies[species == -1] = 0
-        return self_energies.sum(dim=1)
-
-    def subtract_from_dataset(self, species, coordinates, properties):
-        dtype = properties['energies'].dtype
-        device = properties['energies'].device
-        properties['energies'] -= self.sae(species).to(dtype).to(device)
-        return species, coordinates, properties
-
-    def forward(self, species_energies):
-        species, energies = species_energies
-        sae = self.sae(species).to(energies.dtype).to(energies.device)
-        return species, energies + sae
 import torch
-from . import padding
+from . import utils
 
 
 class ANIModel(torch.nn.Module):
@@ -50,7 +50,7 @@ class ANIModel(torch.nn.Module):
         """
         species, aev = species_aev
         species_ = species.flatten()
-        present_species = padding.present_species(species)
+        present_species = utils.present_species(species)
         aev = aev.flatten(0, 1)
 
         output = torch.full_like(species_, self.padding_fill,
...
@@ -6,26 +6,13 @@ import lark
 import struct
 from collections.abc import Mapping
 from .models import ANIModel, Ensemble
+from .utils import EnergyShifter
+from .aev import AEVComputer
 
-buildin_const_file = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
-
-buildin_sae_file = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
-
-buildin_network_dir = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/train0/networks/')
-
-buildin_model_prefix = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/train')
-
-buildin_ensemble = 8
 
 class Constants(Mapping):
 
-    def __init__(self, filename=buildin_const_file):
+    def __init__(self, filename):
         self.filename = filename
         with open(filename) as f:
             for i in f:
@@ -73,7 +60,7 @@ class Constants(Mapping):
         return torch.tensor(rev, dtype=torch.long, device=device)
 
 
-def load_sae(filename=buildin_sae_file):
+def load_sae(filename):
     """Load self energies from NeuroChem sae file"""
     self_energies = {}
     with open(filename) as f:
@@ -241,37 +228,42 @@ def load_atomic_network(filename):
     return torch.nn.Sequential(*layers)
 
 
-def load_model(species, from_=None, ensemble=False):
-    """If from_=None then ensemble must be a boolean. If ensemble=False,
-    then use buildin network0, else use buildin network ensemble.
-    If from_ != None, ensemble must be either False or an integer
-    specifying the number of networks in the ensemble.
-    """
-    if from_ is None:
-        if not isinstance(ensemble, bool):
-            raise TypeError('ensemble must be boolean')
-        if ensemble:
-            from_ = buildin_model_prefix
-            ensemble = buildin_ensemble
-        else:
-            from_ = buildin_network_dir
-    else:
-        if not (ensemble is False or isinstance(ensemble, int)):
-            raise ValueError('invalid argument ensemble')
-
-    def load_single_model(from_):
-        models = []
-        for i in species:
-            filename = os.path.join(from_, 'ANN-{}.nnf'.format(i))
-            models.append((i, load_atomic_network(filename)))
-        return ANIModel(models)
-
-    if ensemble is False:
-        return load_single_model(from_)
-    else:
-        assert isinstance(ensemble, int)
-        models = []
-        for i in range(ensemble):
-            network_dir = os.path.join('{}{}'.format(from_, i), 'networks')
-            models.append(load_single_model(network_dir))
-        return Ensemble(models)
+def load_model(species, from_):
+    models = []
+    for i in species:
+        filename = os.path.join(from_, 'ANN-{}.nnf'.format(i))
+        models.append((i, load_atomic_network(filename)))
+    return ANIModel(models)
+
+
+def load_model_ensemble(species, prefix, count):
+    models = []
+    for i in range(count):
+        network_dir = os.path.join('{}{}'.format(prefix, i), 'networks')
+        models.append(load_model(species, network_dir))
+    return Ensemble(models)
+
+
+class Buildins:
+
+    def __init__(self):
+        self.const_file = pkg_resources.resource_filename(
+            __name__,
+            'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
+        self.consts = Constants(self.const_file)
+        self.aev_computer = AEVComputer(**self.consts)
+        self.sae_file = pkg_resources.resource_filename(
+            __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
+        self.energy_shifter = EnergyShifter(self.consts.species,
+                                            load_sae(self.sae_file))
+        self.ensemble_size = 8
+        self.ensemble_prefix = pkg_resources.resource_filename(
+            __name__, 'resources/ani-1x_dft_x8ens/train')
+        self.models = load_model_ensemble(self.consts.species,
+                                          self.ensemble_prefix,
+                                          self.ensemble_size)
+
+
+buildins = Buildins()
 import torch
-from .. import padding
+from .. import utils
 
 
 class Container(torch.nn.Module):
@@ -18,7 +18,7 @@ class Container(torch.nn.Module):
             _, result = model(sc)
             results[k].append(result)
         results['species'], results['coordinates'] = \
-            padding.pad_and_batch(species_coordinates)
+            utils.pad_and_batch(species_coordinates)
         for k in self.keys:
             results[k] = torch.cat(results[k])
         return results
@@ -5,7 +5,7 @@ from .pyanitools import anidataloader
 import torch
 import torch.utils.data as data
 import pickle
-from .. import padding
+from .. import utils
 
 
 def chunk_counts(counts, split):
@@ -69,7 +69,7 @@ def split_batch(natoms, species, coordinates):
         end = start + i
         s = species[start:end, ...]
         c = coordinates[start:end, ...]
-        s, c = padding.strip_redundant_padding(s, c)
+        s, c = utils.strip_redundant_padding(s, c)
         species_coordinates.append((s, c))
         start = end
     return species_coordinates
@@ -119,7 +119,7 @@ class BatchedANIDataset(Dataset):
                 for i in properties:
                     properties[i].append(torch.from_numpy(m[i])
                                          .type(dtype).to(device))
-            species, coordinates = padding.pad_and_batch(species_coordinates)
+            species, coordinates = utils.pad_and_batch(species_coordinates)
             for i in properties:
                 properties[i] = torch.cat(properties[i])
...
@@ -34,3 +34,29 @@ def strip_redundant_padding(species, coordinates):
     species = species.index_select(1, non_padding)
     coordinates = coordinates.index_select(1, non_padding)
     return species, coordinates
+
+
+class EnergyShifter(torch.nn.Module):
+
+    def __init__(self, species, self_energies):
+        super(EnergyShifter, self).__init__()
+        self_energies_tensor = [self_energies[s] for s in species]
+        self.register_buffer('self_energies_tensor',
+                             torch.tensor(self_energies_tensor,
+                                          dtype=torch.double))
+
+    def sae(self, species):
+        self_energies = self.self_energies_tensor[species]
+        self_energies[species == -1] = 0
+        return self_energies.sum(dim=1)
+
+    def subtract_from_dataset(self, species, coordinates, properties):
+        dtype = properties['energies'].dtype
+        device = properties['energies'].device
+        properties['energies'] -= self.sae(species).to(dtype).to(device)
+        return species, coordinates, properties
+
+    def forward(self, species_energies):
+        species, energies = species_energies
+        sae = self.sae(species).to(energies.dtype).to(energies.device)
+        return species, energies + sae
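
For context, the relocated EnergyShifter is used in two ways by the scripts and tests above: subtract_from_dataset is passed as a dataset transform to remove self atomic energies from the target energies, and forward adds them back to predicted energies when the module is the last stage of a Sequential pipeline. A small sketch of the transform usage (the dataset path and batch size are placeholders):

import torch
import torchani

shift_energy = torchani.buildins.energy_shifter

# sae() sums the per-atom self energies of each molecule; padding atoms
# (species index -1) contribute zero, so padded batches are handled.
species = torch.tensor([[0, 1, 1, -1]])    # placeholder species indices, -1 = padding
molecular_sae = shift_energy.sae(species)  # shape (1,), dtype torch.double

# As a transform, subtract_from_dataset removes these self energies from the
# 'energies' property of each batch, mirroring the training scripts above.
dataset = torchani.training.BatchedANIDataset(
    '/path/to/dataset',                    # placeholder path
    torchani.buildins.consts.species, 256,
    transform=[shift_energy.subtract_from_dataset])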