Unverified Commit ce21d224 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Some API improvements (#73)

parent f146feca
......@@ -12,8 +12,8 @@ ensemble = 8
consts = torchani.neurochem.Constants(const_file)
sae = torchani.neurochem.load_sae(sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species, from_=network_dir,
ensemble=ensemble)
nn = torchani.neurochem.load_model_ensemble(consts.species, network_dir,
ensemble)
shift_energy = torchani.EnergyShifter(consts.species, sae)
model = torch.nn.Sequential(aev_computer, nn, shift_energy)
......
......@@ -17,9 +17,7 @@ def atomic():
def get_or_create_model(filename, device=torch.device('cpu')):
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
aev_computer = torchani.buildins.aev_computer
model = torchani.ANIModel([
('C', atomic()),
('H', atomic()),
......@@ -37,4 +35,4 @@ def get_or_create_model(filename, device=torch.device('cpu')):
model.load_state_dict(torch.load(filename))
else:
torch.save(model.state_dict(), filename)
return model.to(device), torchani.EnergyShifter(consts.species, sae)
return model.to(device)
......@@ -19,16 +19,13 @@ parser.add_argument('--batch_size',
default=1024, type=int)
parser.add_argument('--const_file',
help='File storing constants',
default=torchani.neurochem.buildin_const_file)
default=torchani.buildins.const_file)
parser.add_argument('--sae_file',
help='File storing self atomic energies',
default=torchani.neurochem.buildin_sae_file)
default=torchani.buildins.sae_file)
parser.add_argument('--network_dir',
help='Directory or prefix of directories storing networks',
default=None)
parser.add_argument('--ensemble',
help='Number of models in ensemble',
default=False)
default=torchani.buildins.ensemble_prefix + '0/networks')
parser = parser.parse_args()
# load modules and datasets
......@@ -36,9 +33,7 @@ device = torch.device(parser.device)
consts = torchani.neurochem.Constants(parser.const_file)
sae = torchani.neurochem.load_sae(parser.sae_file)
aev_computer = torchani.AEVComputer(**consts)
nn = torchani.neurochem.load_model(consts.species,
from_=parser.network_dir,
ensemble=parser.ensemble)
nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
model = torch.nn.Sequential(aev_computer, nn)
container = torchani.training.Container({'energies': model})
container = container.to(device)
......
......@@ -51,8 +51,8 @@ device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
nnp, shift_energy = model.get_or_create_model(parser.model_checkpoint,
device=device)
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
shift_energy = torchani.buildins.energy_shifter
training, validation, testing = torchani.training.load_or_create(
parser.dataset_checkpoint, parser.batch_size, nnp[0].species,
parser.dataset_path, device=device,
......
......@@ -21,7 +21,8 @@ parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt', device=device)
nnp = model.get_or_create_model('/tmp/model.pt', device=device)
shift_energy = torchani.buildins.energy_shifter
dataset = torchani.training.BatchedANIDataset(
parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
......
......@@ -11,8 +11,7 @@ N = 97
class TestAEV(unittest.TestCase):
def setUp(self):
self.constants = torchani.neurochem.Constants()
self.aev_computer = torchani.AEVComputer(**self.constants)
self.aev_computer = torchani.buildins.aev_computer
self.radial_length = self.aev_computer.radial_length()
self.tolerance = 1e-5
......@@ -44,7 +43,7 @@ class TestAEV(unittest.TestCase):
coordinates, species, radial, angular, _, _ = pickle.load(f)
species_coordinates.append((species, coordinates))
radial_angular.append((radial, angular))
species, coordinates = torchani.padding.pad_and_batch(
species, coordinates = torchani.utils.pad_and_batch(
species_coordinates)
_, aev = self.aev_computer((species, coordinates))
start = 0
......
......@@ -6,7 +6,7 @@ import unittest
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
batch_size = 256
consts = torchani.neurochem.Constants()
consts = torchani.buildins.consts
class TestData(unittest.TestCase):
......@@ -26,7 +26,7 @@ class TestData(unittest.TestCase):
coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3)
species, coordinates = torchani.padding.pad_and_batch([
species, coordinates = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
......@@ -44,19 +44,19 @@ class TestData(unittest.TestCase):
self.assertGreater(conformations, 0)
s_ = species[start:start+conformations, ...]
c_ = coordinates[start:start+conformations, ...]
s_, c_ = torchani.padding.strip_redundant_padding(s_, c_)
s_, c_ = torchani.utils.strip_redundant_padding(s_, c_)
self._assertTensorEqual(s, s_)
self._assertTensorEqual(c, c_)
start += conformations
s, c = torchani.padding.pad_and_batch(chunks)
s, c = torchani.utils.pad_and_batch(chunks)
self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates)
def testTensorShape(self):
for i in self.ds:
input, output = i
species, coordinates = torchani.padding.pad_and_batch(input)
species, coordinates = torchani.utils.pad_and_batch(input)
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
......
......@@ -13,11 +13,9 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
shift_energy = torchani.EnergyShifter(consts.species, sae)
aev_computer = torchani.buildins.aev_computer
nnp = torchani.buildins.models[0]
shift_energy = torchani.buildins.energy_shifter
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
def testIsomers(self):
......@@ -38,7 +36,7 @@ class TestEnergies(unittest.TestCase):
coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append((species, coordinates))
energies.append(e)
species, coordinates = torchani.padding.pad_and_batch(
species, coordinates = torchani.utils.pad_and_batch(
species_coordinates)
energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates))
......
......@@ -16,22 +16,15 @@ class TestEnsemble(unittest.TestCase):
def _test_molecule(self, coordinates, species):
coordinates = torch.tensor(coordinates, requires_grad=True)
n = torchani.neurochem.buildin_ensemble
prefix = torchani.neurochem.buildin_model_prefix
consts = torchani.neurochem.Constants()
aev = torchani.AEVComputer(**consts)
ensemble = torchani.neurochem.load_model(consts.species, ensemble=True)
aev = torchani.buildins.aev_computer
ensemble = torchani.buildins.models
models = [torch.nn.Sequential(aev, m) for m in ensemble]
ensemble = torch.nn.Sequential(aev, ensemble)
models = [torchani.neurochem.load_model(
consts.species, ensemble=False,
from_=prefix + '{}/networks/'.format(i))
for i in range(n)]
models = [torch.nn.Sequential(aev, m) for m in models]
_, energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates))[1] for m in models]
energy2 = sum(energy2) / n
energy2 = sum(energy2) / len(models)
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item()
force_diff = (force1 - force2).abs().max().item()
......
......@@ -12,9 +12,8 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
consts = torchani.neurochem.Constants()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
aev_computer = torchani.buildins.aev_computer
nnp = torchani.buildins.models[0]
self.model = torch.nn.Sequential(aev_computer, nnp)
def testIsomers(self):
......@@ -39,7 +38,7 @@ class TestForce(unittest.TestCase):
coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.padding.pad_and_batch(
species, coordinates = torchani.utils.pad_and_batch(
species_coordinates)
_, energies = self.model((species, coordinates))
energies = energies.sum()
......
import os
import unittest
import torch
import copy
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator, Events
import torchani
......@@ -15,13 +16,11 @@ threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
consts = torchani.neurochem.Constants()
sae = torchani.neurochem.load_sae()
aev_computer = torchani.AEVComputer(**consts)
nnp = torchani.neurochem.load_model(consts.species)
shift_energy = torchani.EnergyShifter(consts.species, sae)
aev_computer = torchani.buildins.aev_computer
nnp = copy.deepcopy(torchani.buildins.models[0])
shift_energy = torchani.buildins.energy_shifter
ds = torchani.training.BatchedANIDataset(
path, consts.species, batchsize,
path, torchani.buildins.consts.species, batchsize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
......
......@@ -10,7 +10,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
species, coordinates = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
......@@ -33,7 +33,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
species, coordinates = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
......@@ -62,7 +62,7 @@ class TestPadAndBatch(unittest.TestCase):
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.padding.pad_and_batch([
species, coordinates = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
......@@ -82,7 +82,7 @@ class TestPadAndBatch(unittest.TestCase):
def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.padding.present_species(species)
present_species = torchani.utils.present_species(species)
expected = torch.LongTensor([0, 1, 3, 7])
self.assertEqual((expected - present_species).abs().max().item(), 0)
......@@ -97,22 +97,22 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.padding.pad_and_batch([
species12, coordinates12 = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
])
species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.padding.pad_and_batch([
species123, coordinates123 = torchani.utils.pad_and_batch([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
])
species1_, coordinates1_ = torchani.padding.strip_redundant_padding(
species1_, coordinates1_ = torchani.utils.strip_redundant_padding(
species123[:5, ...], coordinates123[:5, ...])
self._assertTensorEqual(species1_, species1)
self._assertTensorEqual(coordinates1_, coordinates1)
species12_, coordinates12_ = torchani.padding.strip_redundant_padding(
species12_, coordinates12_ = torchani.utils.strip_redundant_padding(
species123[:7, ...], coordinates123[:7, ...])
self._assertTensorEqual(species12_, species12)
self._assertTensorEqual(coordinates12_, coordinates12)
......
from .energyshifter import EnergyShifter
from .utils import EnergyShifter
from .models import ANIModel, Ensemble
from .aev import AEVComputer
from . import training
from . import padding
from . import utils
from . import neurochem
from .neurochem import buildins
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
'training', 'padding', 'neurochem']
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
'training', 'utils', 'neurochem']
import torch
import itertools
import math
from . import padding
from . import utils
def _cutoff_cosine(distances, cutoff):
......@@ -363,7 +363,7 @@ class AEVComputer(torch.nn.Module):
def forward(self, species_coordinates):
species, coordinates = species_coordinates
present_species = padding.present_species(species)
present_species = utils.present_species(species)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
......
import torch
class EnergyShifter(torch.nn.Module):
    """Add or remove per-atom self energies from molecular energies.

    Built from an ordered species list and a mapping from species symbol
    to its self atomic energy; the per-species energies are stored as a
    double-precision buffer indexed by numeric species index.
    """

    def __init__(self, species, self_energies):
        super(EnergyShifter, self).__init__()
        # Order the self energies to match the numeric species indices.
        ordered_energies = [self_energies[symbol] for symbol in species]
        self.register_buffer('self_energies_tensor',
                             torch.tensor(ordered_energies,
                                          dtype=torch.double))

    def sae(self, species):
        """Total self atomic energy of each conformation in the batch.

        Padding atoms (species index -1) contribute zero to the sum.
        """
        # Advanced indexing yields a fresh tensor, so the in-place zeroing
        # below does not touch the stored buffer.
        per_atom = self.self_energies_tensor[species]
        per_atom[species == -1] = 0
        return per_atom.sum(dim=1)

    def subtract_from_dataset(self, species, coordinates, properties):
        """Dataset transform: remove self energies from ``properties['energies']``."""
        energies = properties['energies']
        properties['energies'] -= self.sae(species).to(energies.dtype) \
                                                   .to(energies.device)
        return species, coordinates, properties

    def forward(self, species_energies):
        """Add the self atomic energies back onto network energy outputs."""
        species, energies = species_energies
        offset = self.sae(species).to(energies.dtype).to(energies.device)
        return species, energies + offset
import torch
from . import padding
from . import utils
class ANIModel(torch.nn.Module):
......@@ -50,7 +50,7 @@ class ANIModel(torch.nn.Module):
"""
species, aev = species_aev
species_ = species.flatten()
present_species = padding.present_species(species)
present_species = utils.present_species(species)
aev = aev.flatten(0, 1)
output = torch.full_like(species_, self.padding_fill,
......
......@@ -6,26 +6,13 @@ import lark
import struct
from collections.abc import Mapping
from .models import ANIModel, Ensemble
buildin_const_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
buildin_sae_file = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
buildin_network_dir = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train0/networks/')
buildin_model_prefix = pkg_resources.resource_filename(
__name__, 'resources/ani-1x_dft_x8ens/train')
buildin_ensemble = 8
from .utils import EnergyShifter
from .aev import AEVComputer
class Constants(Mapping):
def __init__(self, filename=buildin_const_file):
def __init__(self, filename):
self.filename = filename
with open(filename) as f:
for i in f:
......@@ -73,7 +60,7 @@ class Constants(Mapping):
return torch.tensor(rev, dtype=torch.long, device=device)
def load_sae(filename=buildin_sae_file):
def load_sae(filename):
"""Load self energies from NeuroChem sae file"""
self_energies = {}
with open(filename) as f:
......@@ -241,37 +228,42 @@ def load_atomic_network(filename):
return torch.nn.Sequential(*layers)
def load_model(species, from_=None, ensemble=False):
    """Load NeuroChem network(s) as an ``ANIModel`` or an ``Ensemble``.

    If ``from_`` is None then ``ensemble`` must be a boolean: False loads
    the buildin network0, True loads the full buildin ensemble.  If
    ``from_`` is given, ``ensemble`` must be either False (single model)
    or an integer giving the number of networks in the ensemble.
    """
    if from_ is None:
        if not isinstance(ensemble, bool):
            raise TypeError('ensemble must be boolean')
        if ensemble:
            from_ = buildin_model_prefix
            ensemble = buildin_ensemble
        else:
            from_ = buildin_network_dir
    elif not (ensemble is False or isinstance(ensemble, int)):
        raise ValueError('invalid argument ensemble')

    def _single(directory):
        # One atomic network per species symbol, in species order.
        pairs = [(symbol, load_atomic_network(
            os.path.join(directory, 'ANN-{}.nnf'.format(symbol))))
            for symbol in species]
        return ANIModel(pairs)

    if ensemble is False:
        return _single(from_)
    assert isinstance(ensemble, int)
    members = [_single(os.path.join('{}{}'.format(from_, i), 'networks'))
               for i in range(ensemble)]
    return Ensemble(members)
def load_model(species, from_):
    """Build an ``ANIModel`` from a NeuroChem network directory.

    Expects one ``ANN-<species>.nnf`` file per species symbol inside
    the directory ``from_``.
    """
    atomic_networks = [
        (symbol, load_atomic_network(
            os.path.join(from_, 'ANN-{}.nnf'.format(symbol))))
        for symbol in species
    ]
    return ANIModel(atomic_networks)
def load_model_ensemble(species, prefix, count):
    """Load ``count`` models from ``<prefix><i>/networks`` into an ``Ensemble``.

    Each member ``i`` in ``range(count)`` is loaded with :func:`load_model`
    from the directory ``'{prefix}{i}/networks'``.
    """
    members = [
        load_model(species,
                   os.path.join('{}{}'.format(prefix, index), 'networks'))
        for index in range(count)
    ]
    return Ensemble(members)
class Buildins:
    """Lazy-free container for the buildin ANI-1x (8-network) resources.

    Instantiating this class loads the packaged constants, AEV computer,
    energy shifter and model ensemble from ``resources/ani-1x_dft_x8ens``.
    """

    def __init__(self):
        resource = pkg_resources.resource_filename
        # Constants and the AEV computer built from them.
        self.const_file = resource(
            __name__,
            'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
        self.consts = Constants(self.const_file)
        self.aev_computer = AEVComputer(**self.consts)
        # Self atomic energies and the corresponding energy shifter.
        self.sae_file = resource(
            __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
        self.energy_shifter = EnergyShifter(self.consts.species,
                                            load_sae(self.sae_file))
        # The trained network ensemble (train0 ... train7).
        self.ensemble_size = 8
        self.ensemble_prefix = resource(
            __name__, 'resources/ani-1x_dft_x8ens/train')
        self.models = load_model_ensemble(self.consts.species,
                                          self.ensemble_prefix,
                                          self.ensemble_size)


buildins = Buildins()
import torch
from .. import padding
from .. import utils
class Container(torch.nn.Module):
......@@ -18,7 +18,7 @@ class Container(torch.nn.Module):
_, result = model(sc)
results[k].append(result)
results['species'], results['coordinates'] = \
padding.pad_and_batch(species_coordinates)
utils.pad_and_batch(species_coordinates)
for k in self.keys:
results[k] = torch.cat(results[k])
return results
......@@ -5,7 +5,7 @@ from .pyanitools import anidataloader
import torch
import torch.utils.data as data
import pickle
from .. import padding
from .. import utils
def chunk_counts(counts, split):
......@@ -69,7 +69,7 @@ def split_batch(natoms, species, coordinates):
end = start + i
s = species[start:end, ...]
c = coordinates[start:end, ...]
s, c = padding.strip_redundant_padding(s, c)
s, c = utils.strip_redundant_padding(s, c)
species_coordinates.append((s, c))
start = end
return species_coordinates
......@@ -119,7 +119,7 @@ class BatchedANIDataset(Dataset):
for i in properties:
properties[i].append(torch.from_numpy(m[i])
.type(dtype).to(device))
species, coordinates = padding.pad_and_batch(species_coordinates)
species, coordinates = utils.pad_and_batch(species_coordinates)
for i in properties:
properties[i] = torch.cat(properties[i])
......
......@@ -34,3 +34,29 @@ def strip_redundant_padding(species, coordinates):
species = species.index_select(1, non_padding)
coordinates = coordinates.index_select(1, non_padding)
return species, coordinates
class EnergyShifter(torch.nn.Module):
    """Shift molecular energies by the sum of per-atom self energies."""

    def __init__(self, species, self_energies):
        # ``species`` orders the symbols; ``self_energies`` maps
        # symbol -> self atomic energy.
        super(EnergyShifter, self).__init__()
        values = [self_energies[s] for s in species]
        self.register_buffer(
            'self_energies_tensor',
            torch.tensor(values, dtype=torch.double))

    def sae(self, species):
        """Per-conformation sum of self energies; -1 padding counts as 0."""
        selected = self.self_energies_tensor[species]
        selected[species == -1] = 0
        return selected.sum(dim=1)

    def subtract_from_dataset(self, species, coordinates, properties):
        """Dataset transform removing self energies from the 'energies' entry."""
        target = properties['energies']
        properties['energies'] -= \
            self.sae(species).to(target.dtype).to(target.device)
        return species, coordinates, properties

    def forward(self, species_energies):
        """Add self energies back to per-conformation network energies."""
        species, energies = species_energies
        shift = self.sae(species).to(energies.dtype).to(energies.device)
        return species, energies + shift
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment