Unverified Commit 84439caf authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

merge Container and BatchModel to provide richer information in output (#50)

parent 1d8bba37
......@@ -23,7 +23,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H']
energy = model((species, coordinates))
_, energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species)
force = -derivative
......
......@@ -34,7 +34,7 @@ def get_or_create_model(filename, benchmark=False,
class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
if os.path.isfile(filename):
......
......@@ -55,17 +55,16 @@ training, validation, testing = torchani.data.load_or_create(
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
container = torchani.ignite.Container({'energies': nnp})
parser.optim_args = json.loads(parser.optim_args)
optimizer = getattr(torch.optim, parser.optimizer)
optimizer = optimizer(nnp.parameters(), **parser.optim_args)
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric
'RMSE': torchani.ignite.RMSEMetric('energies')
})
......
......@@ -30,12 +30,11 @@ dataset = torchani.data.ANIDataset(
transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
container, optimizer, torchani.ignite.MSELoss('energies'))
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
......
import sys
if sys.version_info.major >= 3:
import os
import unittest
import torch
import torchani
import torchani.data
import itertools
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset')
chunksize = 32
batch_chunks = 32
class TestBatch(unittest.TestCase):
    """Checks that batched loading and batched inference agree in shape."""

    def testBatchLoadAndInference(self):
        dataset = torchani.data.ANIDataset(path, chunksize)
        loader = torchani.data.dataloader(dataset, batch_chunks)
        aev_computer = torchani.SortedAEV()
        prepare = torchani.PrepareInput(aev_computer.species)
        nnp = torchani.models.NeuroChemNNP(aev_computer.species)
        model = torch.nn.Sequential(prepare, aev_computer, nnp)
        batch_nnp = torchani.models.BatchModel(model)
        # A handful of batches is enough to validate the output shapes.
        for batch_input, batch_output in itertools.islice(loader, 10):
            predicted = batch_nnp(batch_input).squeeze()
            expected_shape = list(batch_output['energies'].shape)
            self.assertListEqual(list(predicted.shape), expected_shape)
if __name__ == '__main__':
unittest.main()
......@@ -12,7 +12,7 @@ if sys.version_info.major >= 3:
def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize)
for i in ds:
for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self):
......
......@@ -19,9 +19,9 @@ class TestEnergies(unittest.TestCase):
self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, energies):
shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file)
energies_ = self.model((species, coordinates)).squeeze()
energies_ = shift_energy.add_sae(energies_, species)
shift_energy = torchani.EnergyShifter()
_, energies_ = self.model((species, coordinates))
energies_ = shift_energy.add_sae(energies_.squeeze(), species)
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -28,9 +28,9 @@ class TestEnsemble(unittest.TestCase):
for i in range(n)]
models = [torch.nn.Sequential(prepare, aev, m) for m in models]
energy1 = ensemble((species, coordinates))
_, energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates)) for m in models]
energy2 = [m((species, coordinates))[1] for m in models]
energy2 = sum(energy2) / n
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item()
......
......@@ -19,7 +19,7 @@ class TestForce(unittest.TestCase):
def _test_molecule(self, coordinates, species, forces):
coordinates = torch.tensor(coordinates, requires_grad=True)
energies = self.model((species, coordinates))
_, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -28,16 +28,15 @@ if sys.version_info.major >= 3:
class Flatten(torch.nn.Module):
def forward(self, x):
return x.flatten()
return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
batch_nnp = torchani.models.BatchModel(model)
container = torchani.ignite.Container({'energies': batch_nnp})
container = torchani.ignite.Container({'energies': model})
optimizer = torch.optim.Adam(container.parameters())
trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric
'RMSE': torchani.ignite.RMSEMetric('energies')
})
@trainer.on(Events.COMPLETED)
......
......@@ -5,6 +5,7 @@ from .pyanitools import anidataloader
import torch
import torch.utils.data as data
import pickle
import collections
class ANIDataset(Dataset):
......@@ -64,7 +65,9 @@ class ANIDataset(Dataset):
self.chunks = chunks
def __getitem__(self, idx):
return self.chunks[idx]
chunk = self.chunks[idx]
input_chunk = {k: chunk[k] for k in ('coordinates', 'species')}
return input_chunk, chunk
def __len__(self):
return len(self.chunks)
......@@ -89,22 +92,21 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
return training, validation, testing
def _collate(batch):
input_keys = ['coordinates', 'species']
inputs = [{k: i[k] for k in input_keys} for i in batch]
outputs = {}
for i in batch:
for j in i:
if j in input_keys:
continue
if j not in outputs:
outputs[j] = []
outputs[j].append(i[j])
for i in outputs:
outputs[i] = torch.cat(outputs[i])
return inputs, outputs
def collate(batch):
    """Recursively collate a batch of samples into batched containers.

    Tensors are concatenated along dim 0; mappings are collated per key,
    except 'coordinates' and 'species', whose per-sample values are kept
    as a plain list (chunks may hold molecules of different sizes, so they
    cannot be concatenated); sequences are collated element-wise.

    Raises:
        ValueError: if the batch elements are of an unsupported type.
    """
    no_collate = ['coordinates', 'species']
    first = batch[0]
    if isinstance(first, torch.Tensor):
        return torch.cat(batch)
    elif isinstance(first, collections.abc.Mapping):
        # collections.abc.Mapping: the bare collections.Mapping alias was
        # deprecated since Python 3.3 and removed in 3.10.
        return {key: ((lambda x: x) if key in no_collate else collate)
                ([d[key] for d in batch])
                for key in first}
    elif isinstance(first, collections.abc.Sequence):
        transposed = zip(*batch)
        return [collate(samples) for samples in transposed]
    else:
        raise ValueError('Unexpected element type: {}'.format(type(first)))
def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
return DataLoader(dataset, batch_chunks, shuffle,
collate_fn=_collate, **kwargs)
collate_fn=collate, **kwargs)
from .container import Container
from .loss_metrics import DictLoss, DictMetric, energy_mse_loss, \
energy_rmse_metric
from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric
__all__ = ['Container', 'DictLoss', 'DictMetric', 'energy_mse_loss',
'energy_rmse_metric']
__all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric']
import torch
from ..models import BatchModel
from ..data import collate
class Container(torch.nn.Module):
......@@ -8,13 +8,15 @@ class Container(torch.nn.Module):
super(Container, self).__init__()
self.keys = models.keys()
for i in models:
if not isinstance(models[i], BatchModel):
raise ValueError('Container must contain batch models')
setattr(self, 'model_' + i, models[i])
def forward(self, batch):
output = {}
for i in self.keys:
model = getattr(self, 'model_' + i)
output[i] = model(batch)
return output
all_results = []
for i in zip(batch['species'], batch['coordinates']):
results = {}
for k in self.keys:
model = getattr(self, 'model_' + k)
_, results[k] = model(i)
all_results.append(results)
batch.update(collate(all_results))
return batch
......@@ -4,6 +4,14 @@ from ignite.metrics import RootMeanSquaredError
import torch
def num_atoms(input):
    """Return a flat tensor with the atom count repeated per conformation.

    `input['species']` and `input['coordinates']` are parallel lists; each
    coordinate tensor has one row per conformation, so each chunk
    contributes ``coords.shape[0]`` copies of its molecule's atom count.
    The result matches the coordinates' dtype and device.
    """
    counts = [
        torch.full((coords.shape[0],), len(species),
                   dtype=coords.dtype, device=coords.device)
        for species, coords in zip(input['species'], input['coordinates'])
    ]
    return torch.cat(counts)
class DictLoss(_Loss):
def __init__(self, key, loss):
......@@ -33,5 +41,9 @@ class DictMetric(Metric):
return self.metric.compute()
energy_mse_loss = DictLoss('energies', torch.nn.MSELoss())
energy_rmse_metric = DictMetric('energies', RootMeanSquaredError())
def MSELoss(key):
    """Build a mean-squared-error loss reading prediction/target at *key*."""
    loss = torch.nn.MSELoss()
    return DictLoss(key, loss)
def RMSEMetric(key):
    """Build an ignite RMSE metric reading prediction/target at *key*."""
    metric = RootMeanSquaredError()
    return DictMetric(key, metric)
from .custom import CustomModel
from .neurochem_nnp import NeuroChemNNP
from .batch import BatchModel
__all__ = ['CustomModel', 'NeuroChemNNP', 'BatchModel']
__all__ = ['CustomModel', 'NeuroChemNNP']
......@@ -85,4 +85,4 @@ class ANIModel(BenchmarkedModule):
per_species_outputs = torch.cat(per_species_outputs, dim=1)
molecule_output = self.reducer(per_species_outputs, dim=1)
return molecule_output
return species, molecule_output
import torch
class BatchModel(torch.nn.Module):
    """Wrap a single-chunk model so it can consume a whole batch.

    The wrapped model takes a ``(species, coordinates)`` pair; this module
    applies it to every chunk dict in the batch and concatenates the
    resulting tensors along dim 0.
    """

    def __init__(self, model):
        super(BatchModel, self).__init__()
        self.model = model

    def forward(self, batch):
        outputs = [self.model((chunk['species'], chunk['coordinates']))
                   for chunk in batch]
        return torch.cat(outputs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment