Unverified Commit 50fb6cdb authored by Gao, Xiang, committed by GitHub

Modify training related APIs to batch the whole data by padding (#63)

parent b7cab4f1
No preview for this file type (5 binary test-data files changed)
@@ -14,7 +14,6 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
aev_computer = torchani.AEVComputer()
self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
@@ -24,7 +23,6 @@ class TestEnergies(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, energies, _ = pickle.load(f)
species, coordinates = self.prepare((species, coordinates))
_, energies_ = self.model((species, coordinates))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
@@ -36,8 +34,7 @@ class TestEnergies(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, e, _ = pickle.load(f)
species_coordinates.append(
self.prepare((species, coordinates)))
species_coordinates.append((species, coordinates))
energies.append(e)
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
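The pattern above — collecting raw (species, coordinates) pairs and merging them with torchani.padding.pad_and_batch — replaces the per-molecule PrepareInput step. The diff does not show pad_and_batch itself, so the following is only a minimal sketch of the convention the rest of this commit relies on (species padded with -1, coordinates zero-padded); pad_and_batch_sketch is a hypothetical name, not the committed implementation:

    import torch

    def pad_and_batch_sketch(species_coordinates):
        # Illustrative only: molecules of different sizes share one batch
        # dimension by padding species with -1 and coordinates with zeros.
        max_atoms = max(c.shape[1] for _, c in species_coordinates)
        species_list, coordinates_list = [], []
        for s, c in species_coordinates:
            conformations, atoms = c.shape[0], c.shape[1]
            if s.dim() == 1:
                # broadcast a single species row over all conformations
                s = s.unsqueeze(0).expand(conformations, -1)
            pad = max_atoms - atoms
            species_list.append(
                torch.cat([s, s.new_full((conformations, pad), -1)], dim=1))
            coordinates_list.append(
                torch.cat([c, c.new_zeros(conformations, pad, 3)], dim=1))
        return torch.cat(species_list), torch.cat(coordinates_list)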
import torch
import torchani
import unittest
import random
class TestEnergyShifter(unittest.TestCase):
def setUp(self):
self.tol = 1e-5
self.species = torchani.AEVComputer().species
self.prepare = torchani.PrepareInput(self.species)
self.shift_energy = torchani.EnergyShifter(self.species)
def testSAEMatch(self):
species_coordinates = []
saes = []
for _ in range(10):
k = random.choice(range(5, 30))
species = random.choices(self.species, k=k)
coordinates = torch.empty(1, k, 3)
species_coordinates.append(self.prepare((species, coordinates)))
e = self.shift_energy.sae_from_list(species)
saes.append(e)
species, _ = torchani.padding.pad_and_batch(species_coordinates)
saes_ = self.shift_energy.sae_from_tensor(species)
saes = torch.tensor(saes, dtype=saes_.dtype, device=saes_.device)
self.assertLess((saes - saes_).abs().max(), self.tol)
if __name__ == '__main__':
unittest.main()
@@ -19,14 +19,13 @@ class TestEnsemble(unittest.TestCase):
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev.species)
ensemble = torchani.models.NeuroChemNNP(aev.species, ensemble=True)
ensemble = torch.nn.Sequential(prepare, aev, ensemble)
ensemble = torch.nn.Sequential(aev, ensemble)
models = [torchani.models.
NeuroChemNNP(aev.species, ensemble=False,
from_=prefix + '{}/networks/'.format(i))
for i in range(n)]
models = [torch.nn.Sequential(prepare, aev, m) for m in models]
models = [torch.nn.Sequential(aev, m) for m in models]
_, energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
@@ -13,10 +13,8 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
aev_computer = torchani.AEVComputer()
self.prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
self.model = torch.nn.Sequential(self.prepare, aev_computer, nnp)
self.prepared2e = torch.nn.Sequential(aev_computer, nnp)
self.model = torch.nn.Sequential(aev_computer, nnp)
def testIsomers(self):
for i in range(N):
@@ -37,13 +35,12 @@ class TestForce(unittest.TestCase):
datafile = os.path.join(path, 'test_data/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, forces = pickle.load(f)
species, coordinates = self.prepare((species, coordinates))
coordinates = torch.tensor(coordinates, requires_grad=True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.padding.pad_and_batch(
species_coordinates)
_, energies = self.prepared2e((species, coordinates))
_, energies = self.model((species, coordinates))
energies = energies.sum()
for coordinates, forces in coordinates_forces:
derivative = torch.autograd.grad(energies, coordinates,
import sys
if sys.version_info.major >= 3:
import os
import unittest
import torch
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator, Events
import torchani
import torchani.data
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
chunksize = 4
threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
aev_computer = torchani.AEVComputer()
prepare = torchani.PrepareInput(aev_computer.species)
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
ds = torchani.data.ANIDataset(
path, chunksize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
loader = torchani.data.dataloader(ds, 1)
class Flatten(torch.nn.Module):
def forward(self, x):
return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters())
loss = torchani.ignite.TransformedLoss(
torchani.ignite.MSELoss('energies'),
lambda x: torch.exp(x) - 1)
trainer = create_supervised_trainer(
container, optimizer, loss)
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.RMSEMetric('energies')
})
@trainer.on(Events.COMPLETED)
def completes(trainer):
evaluator.run(loader)
metrics = evaluator.state.metrics
self.assertLess(metrics['RMSE'], threshold)
self.assertLess(trainer.state.output, threshold)
trainer.run(loader, max_epochs=1000)
if __name__ == '__main__':
unittest.main()
import os
import unittest
import torch
from ignite.engine import create_supervised_trainer, \
create_supervised_evaluator, Events
import torchani
import torchani.training
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
batchsize = 4
threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
aev_computer = torchani.AEVComputer()
nnp = torchani.models.NeuroChemNNP(aev_computer.species)
shift_energy = torchani.EnergyShifter(aev_computer.species)
ds = torchani.training.BatchedANIDataset(
path, aev_computer.species, batchsize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
class Flatten(torch.nn.Module):
def forward(self, x):
return x[0], x[1].flatten()
model = torch.nn.Sequential(aev_computer, nnp, Flatten())
container = torchani.training.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters())
loss = torchani.training.TransformedLoss(
torchani.training.MSELoss('energies'),
lambda x: torch.exp(x) - 1)
trainer = create_supervised_trainer(
container, optimizer, loss)
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.training.RMSEMetric('energies')
})
@trainer.on(Events.COMPLETED)
def completes(trainer):
evaluator.run(ds)
metrics = evaluator.state.metrics
self.assertLess(metrics['RMSE'], threshold)
self.assertLess(trainer.state.output, threshold)
trainer.run(ds, max_epochs=1000)
if __name__ == '__main__':
unittest.main()
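Note that each element of BatchedANIDataset is already a complete ((species, coordinates), properties) minibatch, which is why this test passes ds straight to trainer.run and evaluator.run with no DataLoader or collate function. Roughly, the ignite trainer above reduces to this plain loop (a sketch reusing ds, container, optimizer and loss as defined in the test; not part of the commit):

    # Rough equivalent of create_supervised_trainer for this setup (illustrative):
    for epoch in range(1000):
        for batch_input, batch_target in ds:
            optimizer.zero_grad()
            batch_output = container(batch_input)
            batch_loss = loss(batch_output, batch_target)
            batch_loss.backward()
            optimizer.step()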
@@ -8,7 +8,7 @@ import torchani
import pickle
from torchani import buildin_const_file, buildin_sae_file, \
buildin_network_dir
import torchani.pyanitools
import torchani.training.pyanitools
path = os.path.dirname(os.path.realpath(__file__))
conv_au_ev = 27.21138505
@@ -58,19 +58,24 @@ class NeuroChem (torchani.aev.AEVComputer):
energies, forces
aev = torchani.AEVComputer()
ncaev = NeuroChem().to(torch.device('cpu'))
mol_count = 0
species_indices = {aev.species[i]: i for i in range(len(aev.species))}
for i in [1, 2, 3, 4]:
data_file = os.path.join(
path, '../dataset/ani_gdb_s0{}.h5'.format(i))
adl = torchani.pyanitools.anidataloader(data_file)
adl = torchani.training.pyanitools.anidataloader(data_file)
for data in adl:
coordinates = data['coordinates'][:10, :]
coordinates = torch.from_numpy(coordinates).type(ncaev.EtaR.dtype)
species = data['species']
species = torch.tensor([species_indices[i] for i in data['species']],
dtype=torch.long, device=torch.device('cpu')) \
.expand(10, -1)
smiles = ''.join(data['smiles'])
radial, angular, energies, forces = ncaev(coordinates, species)
radial, angular, energies, forces = ncaev(coordinates, data['species'])
pickleobj = (coordinates, species, radial, angular, energies, forces)
dumpfile = os.path.join(
path, '../tests/test_data/{}'.format(mol_count))
from .energyshifter import EnergyShifter
from . import models
from . import data
from . import ignite
from . import training
from . import padding
from .aev import AEVComputer, PrepareInput
from .aev import AEVComputer
from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
buildin_model_prefix, buildin_ensemble
__all__ = ['PrepareInput', 'AEVComputer', 'EnergyShifter',
'models', 'data', 'padding', 'ignite',
'models', 'training', 'padding', 'ignite',
'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir',
'buildin_model_prefix', 'buildin_ensemble']
@@ -107,40 +107,6 @@ class AEVComputerBase(BenchmarkedModule):
raise NotImplementedError('subclass must override this method')
class PrepareInput(torch.nn.Module):
def __init__(self, species):
super(PrepareInput, self).__init__()
self.species = species
def species_to_tensor(self, species, device):
"""Convert species list into a long tensor.
Parameters
----------
species : list
List of string for the species of each atoms.
device : torch.device
The device to store tensor
Returns
-------
torch.Tensor
Long tensor for the species, where a value k means the species is
the same as self.species[k].
"""
indices = {self.species[i]: i for i in range(len(self.species))}
values = [indices[i] for i in species]
return torch.tensor(values, dtype=torch.long, device=device)
def forward(self, species_coordinates):
species, coordinates = species_coordinates
conformations = coordinates.shape[0]
species = self.species_to_tensor(species, coordinates.device)
species = species.expand(conformations, -1)
return species, coordinates
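With PrepareInput deleted, species-to-index conversion is no longer a module in the model pipeline; it happens once, at data-loading time (see BatchedANIDataset below). A minimal sketch of the equivalent conversion, with an assumed species ordering for illustration:

    import torch

    all_species = ('H', 'C', 'N', 'O')   # assumed ordering, for illustration only
    indices = {s: i for i, s in enumerate(all_species)}
    species = ['C', 'H', 'H', 'H', 'H']  # e.g. methane
    species_tensor = torch.tensor([indices[s] for s in species], dtype=torch.long)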
def _cutoff_cosine(distances, cutoff):
"""Compute the elementwise cutoff cosine function
@@ -319,8 +285,11 @@ class AEVComputer(AEVComputerBase):
"""Shape (conformations, atoms, atoms) storing Rij distances"""
padding_mask = (species == -1).unsqueeze(1)
distances = torch.where(padding_mask, torch.tensor(math.inf),
distances)
distances = torch.where(
padding_mask,
torch.tensor(math.inf, dtype=self.EtaR.dtype,
device=self.EtaR.device),
distances)
distances, indices = distances.sort(-1)
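The replacement above fixes a dtype/device mismatch: torch.where requires its two value arguments to agree in dtype and device, and a bare torch.tensor(math.inf) is always float32 on the CPU. A standalone sketch of the corrected pattern, with a made-up batch:

    import math
    import torch

    species = torch.tensor([[0, 1, -1]])       # one conformation, last atom is padding
    distances = torch.rand(1, 3, 3, dtype=torch.float64)
    padding_mask = (species == -1).unsqueeze(1)
    # The fill value must match `distances` in dtype and device; a plain
    # torch.tensor(math.inf) would fail for float64 or CUDA inputs:
    inf = torch.tensor(math.inf, dtype=distances.dtype, device=distances.device)
    distances = torch.where(padding_mask, inf, distances)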
@@ -22,21 +22,18 @@ class EnergyShifter(torch.nn.Module):
torch.tensor(self_energies_tensor,
dtype=torch.double))
def sae_from_list(self, species):
energies = [self.self_energies[i] for i in species]
return sum(energies)
def sae_from_tensor(self, species):
def sae(self, species):
self_energies = self.self_energies_tensor[species]
self_energies[species == -1] = 0
return self_energies.sum(dim=1)
def subtract_from_dataset(self, data):
sae = self.sae_from_list(data['species'])
data['energies'] -= sae
return data
def subtract_from_dataset(self, species, coordinates, properties):
dtype = properties['energies'].dtype
device = properties['energies'].device
properties['energies'] -= self.sae(species).to(dtype).to(device)
return species, coordinates, properties
def forward(self, species_energies):
species, energies = species_energies
sae = self.sae_from_tensor(species).to(energies.dtype)
sae = self.sae(species).to(energies.dtype).to(energies.device)
return species, energies + sae
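The two former methods are merged into a single sae that operates on padded species tensors directly: gather per-atom self energies by index, zero the padding slots, and sum over atoms. A worked sketch with made-up self energies for a two-species model:

    import torch

    self_energies_tensor = torch.tensor([-0.5, -38.0], dtype=torch.double)  # fake SAEs
    species = torch.tensor([[0, 1, -1],    # two-atom molecule plus one padding atom
                            [1, 1,  1]])
    self_energies = self_energies_tensor[species]  # index -1 wraps to the last entry...
    self_energies[species == -1] = 0               # ...so padding slots are zeroed here
    sae = self_energies.sum(dim=1)                 # tensor([-38.5, -114.0])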
@@ -27,3 +27,10 @@ def present_species(species):
if present_species[0].item() == -1:
present_species = present_species[1:]
return present_species
def strip_redundant_padding(species, coordinates):
    # drop atom columns that are padding (-1) in every conformation
    non_padding = (species >= 0).any(dim=0).nonzero().view(-1)
    species = species.index_select(1, non_padding)
    coordinates = coordinates.index_select(1, non_padding)
    return species, coordinates
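strip_redundant_padding is useful after slicing a minibatch out of the fully padded dataset: atom columns that are padding in every conformation of that slice can be dropped before the AEV computation. Usage sketch (values are made up):

    import torch
    import torchani

    species = torch.tensor([[0, 1, -1, -1],
                            [1, 1,  2, -1]])   # last column is padding in every row
    coordinates = torch.rand(2, 4, 3)
    species, coordinates = torchani.padding.strip_redundant_padding(
        species, coordinates)
    # species.shape is now (2, 3): the all-padding column is gone,
    # but the mixed column (index 2) is kept.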
from .container import Container
from .data import BatchedANIDataset, load_or_create
from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
TransformedLoss
from . import pyanitools
__all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric',
'TransformedLoss']
__all__ = ['Container', 'BatchedANIDataset', 'load_or_create', 'DictLoss',
'DictMetric', 'MSELoss', 'RMSEMetric', 'TransformedLoss',
'pyanitools']
import torch
from ..data import collate
class Container(torch.nn.Module):
@@ -10,13 +9,13 @@ class Container(torch.nn.Module):
for i in models:
setattr(self, 'model_' + i, models[i])
def forward(self, batch):
all_results = []
for i in zip(batch['species'], batch['coordinates']):
results = {}
for k in self.keys:
model = getattr(self, 'model_' + k)
_, results[k] = model(i)
all_results.append(results)
batch.update(collate(all_results))
return batch
def forward(self, species_coordinates):
species, coordinates = species_coordinates
results = {
'species': species,
'coordinates': coordinates,
}
for k in self.keys:
model = getattr(self, 'model_' + k)
_, results[k] = model((species, coordinates))
return results
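The rewritten forward applies each named model to one padded (species, coordinates) batch and returns a flat result dict, instead of looping over per-chunk lists and collating. A usage sketch with the buildin NeuroChem network (the species indices and coordinates are assumed values):

    import torch
    import torchani
    from torchani.training import Container

    aev_computer = torchani.AEVComputer()
    nnp = torchani.models.NeuroChemNNP(aev_computer.species)
    model = torch.nn.Sequential(aev_computer, nnp)
    container = Container({'energies': model})

    species = torch.tensor([[0, 0]])            # assumed: two atoms of species index 0
    coordinates = torch.tensor([[[0.0, 0.0, 0.0],
                                 [0.0, 0.0, 0.7]]])
    results = container((species, coordinates))
    # results now holds 'species', 'coordinates' and 'energies' for the batch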
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Dataset
from os.path import join, isfile, isdir
import os
from .pyanitools import anidataloader
import torch
import torch.utils.data as data
import pickle
import collections.abc
from .. import padding
class ANIDataset(Dataset):
class BatchedANIDataset(Dataset):
def __init__(self, path, chunk_size, shuffle=True, properties=['energies'],
transform=(), dtype=torch.get_default_dtype(),
device=torch.device('cpu')):
super(ANIDataset, self).__init__()
def __init__(self, path, species, batch_size, shuffle=True,
properties=['energies'], transform=(),
dtype=torch.get_default_dtype(), device=torch.device('cpu')):
super(BatchedANIDataset, self).__init__()
self.path = path
self.chunks_size = chunk_size
self.species = species
self.species_indices = {
self.species[i]: i for i in range(len(self.species))}
self.batch_size = batch_size
self.shuffle = shuffle
self.properties = properties
self.dtype = dtype
@@ -33,51 +36,68 @@ class ANIDataset(Dataset):
else:
raise ValueError('Bad path')
# generate chunks
chunks = []
# load full dataset
species_coordinates = []
properties = {k: [] for k in self.properties}
for f in files:
for m in anidataloader(f):
full = {
'coordinates': torch.from_numpy(m['coordinates'])
.type(dtype).to(device)
}
conformations = full['coordinates'].shape[0]
for i in properties:
full[i] = torch.from_numpy(m[i]).type(dtype).to(device)
species = m['species']
if shuffle:
indices = torch.randperm(conformations, device=device)
else:
indices = torch.arange(conformations, dtype=torch.int64,
device=device)
num_chunks = (conformations + chunk_size - 1) // chunk_size
for i in range(num_chunks):
chunk_start = i * chunk_size
chunk_end = min(chunk_start + chunk_size, conformations)
chunk_indices = indices[chunk_start:chunk_end]
chunk = {}
for j in full:
chunk[j] = full[j].index_select(0, chunk_indices)
chunk['species'] = species
for t in transform:
chunk = t(chunk)
chunks.append(chunk)
self.chunks = chunks
indices = [self.species_indices[i] for i in species]
species = torch.tensor(indices, dtype=torch.long,
device=device)
coordinates = torch.from_numpy(m['coordinates']) \
.type(dtype).to(device)
species_coordinates.append((species, coordinates))
for i in properties:
properties[i].append(torch.from_numpy(m[i])
.type(dtype).to(device))
species, coordinates = padding.pad_and_batch(species_coordinates)
for i in properties:
properties[i] = torch.cat(properties[i])
# shuffle if required
conformations = coordinates.shape[0]
if shuffle:
indices = torch.randperm(conformations, device=device)
species = species.index_select(0, indices)
coordinates = coordinates.index_select(0, indices)
for i in properties:
properties[i] = properties[i].index_select(0, indices)
# do transformations on data
for t in transform:
species, coordinates, properties = t(species, coordinates,
properties)
# split into minibatches, and strip redundant padding
batches = []
num_batches = (conformations + batch_size - 1) // batch_size
for i in range(num_batches):
start = i * batch_size
end = min((i + 1) * batch_size, conformations)
species_batch = species[start:end, ...]
coordinates_batch = coordinates[start:end, ...]
properties_batch = {
k: properties[k][start:end, ...] for k in properties
}
batches.append(((species_batch, coordinates_batch),
properties_batch))
self.batches = batches
def __getitem__(self, idx):
chunk = self.chunks[idx]
input_chunk = {k: chunk[k] for k in ('coordinates', 'species')}
return input_chunk, chunk
return self.batches[idx]
def __len__(self):
return len(self.chunks)
return len(self.batches)
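Since every item is a pre-padded minibatch, indexing and iteration need no collate step. Usage sketch (the path and batch size are placeholders):

    import torchani
    from torchani.training import BatchedANIDataset

    aev_computer = torchani.AEVComputer()
    ds = BatchedANIDataset('dataset/ani_gdb_s01.h5',   # placeholder path
                           aev_computer.species, batch_size=256)
    (species, coordinates), properties = ds[0]         # one padded minibatch
    assert species.shape[0] == properties['energies'].shape[0]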
def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
def load_or_create(checkpoint, batch_size, species, dataset_path,
*args, **kwargs):
"""Generate a 80-10-10 split of the dataset, and checkpoint
the resulting dataset"""
if not os.path.isfile(checkpoint):
full_dataset = ANIDataset(dataset_path, chunk_size, *args, **kwargs)
full_dataset = BatchedANIDataset(dataset_path, species, batch_size,
*args, **kwargs)
training_size = int(len(full_dataset) * 0.8)
validation_size = int(len(full_dataset) * 0.1)
testing_size = len(full_dataset) - training_size - validation_size
@@ -90,23 +110,3 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
with open(checkpoint, 'rb') as f:
training, validation, testing = pickle.load(f)
return training, validation, testing
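The split is computed once and pickled, so later runs with the same checkpoint file reuse an identical split. Usage sketch (file names are placeholders):

    import torchani
    from torchani.training import load_or_create

    training, validation, testing = load_or_create(
        'split-checkpoint.pkl',      # created on the first run, reused afterwards
        256, torchani.AEVComputer().species, 'dataset/ani_gdb_s01.h5')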
def collate(batch):
no_collate = ['coordinates', 'species']
if isinstance(batch[0], torch.Tensor):
return torch.cat(batch)
elif isinstance(batch[0], collections.abc.Mapping):
return {key: ((lambda x: x) if key in no_collate else collate)
([d[key] for d in batch])
for key in batch[0]}
elif isinstance(batch[0], collections.abc.Sequence):
transposed = zip(*batch)
return [collate(samples) for samples in transposed]
else:
raise ValueError('Unexpected element type: {}'.format(type(batch[0])))
def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
return DataLoader(dataset, batch_chunks, shuffle,
collate_fn=collate, **kwargs)
@@ -17,17 +17,10 @@ class DictLoss(_Loss):
class _PerAtomDictLoss(DictLoss):
@staticmethod
def num_atoms(input):
ret = []
for s, c in zip(input['species'], input['coordinates']):
ret.append(torch.full((c.shape[0],), len(s),
dtype=c.dtype, device=c.device))
return torch.cat(ret)
def forward(self, input, other):
loss = self.loss(input[self.key], other[self.key])
loss /= self.num_atoms(input)
num_atoms = (input['species'] >= 0).sum(dim=1)
loss /= num_atoms.to(loss.dtype).to(loss.device)
n = loss.numel()
return loss.sum() / n
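With padded batches, the atom count per conformation can be read straight off the species tensor — padding atoms are -1, so a comparison and a sum replace the old zip over species/coordinates lists. A small worked example:

    import torch

    species = torch.tensor([[0, 1, -1],
                            [1, 1,  1]])
    num_atoms = (species >= 0).sum(dim=1)      # tensor([2, 3]); padding not counted
    loss = torch.tensor([0.4, 0.9])            # hypothetical per-conformation losses
    per_atom = loss / num_atoms.to(loss.dtype) # normalize before averaging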