Unverified Commit 84439caf authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

merge Container and BatchModel to provide richer information in output (#50)

parent 1d8bba37
...@@ -23,7 +23,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], ...@@ -23,7 +23,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
requires_grad=True) requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H'] species = ['C', 'H', 'H', 'H', 'H']
energy = model((species, coordinates)) _, energy = model((species, coordinates))
derivative = torch.autograd.grad(energy.sum(), coordinates)[0] derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species) energy = shift_energy.add_sae(energy, species)
force = -derivative force = -derivative
......
...@@ -34,7 +34,7 @@ def get_or_create_model(filename, benchmark=False, ...@@ -34,7 +34,7 @@ def get_or_create_model(filename, benchmark=False,
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
def forward(self, x): def forward(self, x):
return x.flatten() return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, model, Flatten()) model = torch.nn.Sequential(prepare, aev_computer, model, Flatten())
if os.path.isfile(filename): if os.path.isfile(filename):
......
...@@ -55,17 +55,16 @@ training, validation, testing = torchani.data.load_or_create( ...@@ -55,17 +55,16 @@ training, validation, testing = torchani.data.load_or_create(
training = torchani.data.dataloader(training, parser.batch_chunks) training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks) validation = torchani.data.dataloader(validation, parser.batch_chunks)
nnp = model.get_or_create_model(parser.model_checkpoint, device=device) nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
batch_nnp = torchani.models.BatchModel(nnp) container = torchani.ignite.Container({'energies': nnp})
container = torchani.ignite.Container({'energies': batch_nnp})
parser.optim_args = json.loads(parser.optim_args) parser.optim_args = json.loads(parser.optim_args)
optimizer = getattr(torch.optim, parser.optimizer) optimizer = getattr(torch.optim, parser.optimizer)
optimizer = optimizer(nnp.parameters(), **parser.optim_args) optimizer = optimizer(nnp.parameters(), **parser.optim_args)
trainer = ignite.engine.create_supervised_trainer( trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss) container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = ignite.engine.create_supervised_evaluator(container, metrics={ evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric 'RMSE': torchani.ignite.RMSEMetric('energies')
}) })
......
...@@ -30,12 +30,11 @@ dataset = torchani.data.ANIDataset( ...@@ -30,12 +30,11 @@ dataset = torchani.data.ANIDataset(
transform=[shift_energy.dataset_subtract_sae]) transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks) dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device) nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
batch_nnp = torchani.models.BatchModel(nnp) container = torchani.ignite.Container({'energies': nnp})
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters()) optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer( trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss) container, optimizer, torchani.ignite.MSELoss('energies'))
@trainer.on(ignite.engine.Events.EPOCH_STARTED) @trainer.on(ignite.engine.Events.EPOCH_STARTED)
......
import sys

# Everything below relies on Python-3-only behavior in torchani, so the
# whole test module is skipped (defines nothing) on Python 2.
if sys.version_info.major >= 3:
    import os
    import unittest
    import torch
    import torchani
    import torchani.data
    import itertools

    # Path to the sample ANI dataset that ships next to the test directory.
    path = os.path.dirname(os.path.realpath(__file__))
    path = os.path.join(path, '../dataset')
    # Molecules per chunk, and chunks per batch fed to the dataloader.
    chunksize = 32
    batch_chunks = 32

    class TestBatch(unittest.TestCase):
        """Smoke test: batched inference yields one energy per molecule."""

        def testBatchLoadAndInference(self):
            # Build the data pipeline and a NeuroChem-based model stack.
            ds = torchani.data.ANIDataset(path, chunksize)
            loader = torchani.data.dataloader(ds, batch_chunks)
            aev_computer = torchani.SortedAEV()
            prepare = torchani.PrepareInput(aev_computer.species)
            nnp = torchani.models.NeuroChemNNP(aev_computer.species)
            model = torch.nn.Sequential(prepare, aev_computer, nnp)
            batch_nnp = torchani.models.BatchModel(model)
            # Only the first 10 batches are checked to keep the test fast.
            for batch_input, batch_output in itertools.islice(loader, 10):
                batch_output_ = batch_nnp(batch_input).squeeze()
                # Predicted energies must have the same shape as the labels.
                self.assertListEqual(list(batch_output_.shape),
                                     list(batch_output['energies'].shape))


if __name__ == '__main__':
    unittest.main()
...@@ -12,7 +12,7 @@ if sys.version_info.major >= 3: ...@@ -12,7 +12,7 @@ if sys.version_info.major >= 3:
def _test_chunksize(self, chunksize): def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize) ds = torchani.data.ANIDataset(path, chunksize)
for i in ds: for i, _ in ds:
self.assertLessEqual(i['coordinates'].shape[0], chunksize) self.assertLessEqual(i['coordinates'].shape[0], chunksize)
def testChunk64(self): def testChunk64(self):
......
...@@ -19,9 +19,9 @@ class TestEnergies(unittest.TestCase): ...@@ -19,9 +19,9 @@ class TestEnergies(unittest.TestCase):
self.model = torch.nn.Sequential(prepare, aev_computer, nnp) self.model = torch.nn.Sequential(prepare, aev_computer, nnp)
def _test_molecule(self, coordinates, species, energies): def _test_molecule(self, coordinates, species, energies):
shift_energy = torchani.EnergyShifter(torchani.buildin_sae_file) shift_energy = torchani.EnergyShifter()
energies_ = self.model((species, coordinates)).squeeze() _, energies_ = self.model((species, coordinates))
energies_ = shift_energy.add_sae(energies_, species) energies_ = shift_energy.add_sae(energies_.squeeze(), species)
max_diff = (energies - energies_).abs().max().item() max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance) self.assertLess(max_diff, self.tolerance)
......
...@@ -28,9 +28,9 @@ class TestEnsemble(unittest.TestCase): ...@@ -28,9 +28,9 @@ class TestEnsemble(unittest.TestCase):
for i in range(n)] for i in range(n)]
models = [torch.nn.Sequential(prepare, aev, m) for m in models] models = [torch.nn.Sequential(prepare, aev, m) for m in models]
energy1 = ensemble((species, coordinates)) _, energy1 = ensemble((species, coordinates))
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0] force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m((species, coordinates)) for m in models] energy2 = [m((species, coordinates))[1] for m in models]
energy2 = sum(energy2) / n energy2 = sum(energy2) / n
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0] force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item() energy_diff = (energy1 - energy2).abs().max().item()
......
...@@ -19,7 +19,7 @@ class TestForce(unittest.TestCase): ...@@ -19,7 +19,7 @@ class TestForce(unittest.TestCase):
def _test_molecule(self, coordinates, species, forces): def _test_molecule(self, coordinates, species, forces):
coordinates = torch.tensor(coordinates, requires_grad=True) coordinates = torch.tensor(coordinates, requires_grad=True)
energies = self.model((species, coordinates)) _, energies = self.model((species, coordinates))
derivative = torch.autograd.grad(energies.sum(), coordinates)[0] derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item() max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance) self.assertLess(max_diff, self.tolerance)
......
...@@ -28,16 +28,15 @@ if sys.version_info.major >= 3: ...@@ -28,16 +28,15 @@ if sys.version_info.major >= 3:
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
def forward(self, x): def forward(self, x):
return x.flatten() return x[0], x[1].flatten()
model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten()) model = torch.nn.Sequential(prepare, aev_computer, nnp, Flatten())
batch_nnp = torchani.models.BatchModel(model) container = torchani.ignite.Container({'energies': model})
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(container.parameters()) optimizer = torch.optim.Adam(container.parameters())
trainer = create_supervised_trainer( trainer = create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss) container, optimizer, torchani.ignite.MSELoss('energies'))
evaluator = create_supervised_evaluator(container, metrics={ evaluator = create_supervised_evaluator(container, metrics={
'RMSE': torchani.ignite.energy_rmse_metric 'RMSE': torchani.ignite.RMSEMetric('energies')
}) })
@trainer.on(Events.COMPLETED) @trainer.on(Events.COMPLETED)
......
...@@ -5,6 +5,7 @@ from .pyanitools import anidataloader ...@@ -5,6 +5,7 @@ from .pyanitools import anidataloader
import torch import torch
import torch.utils.data as data import torch.utils.data as data
import pickle import pickle
import collections
class ANIDataset(Dataset): class ANIDataset(Dataset):
...@@ -64,7 +65,9 @@ class ANIDataset(Dataset): ...@@ -64,7 +65,9 @@ class ANIDataset(Dataset):
self.chunks = chunks self.chunks = chunks
def __getitem__(self, idx): def __getitem__(self, idx):
return self.chunks[idx] chunk = self.chunks[idx]
input_chunk = {k: chunk[k] for k in ('coordinates', 'species')}
return input_chunk, chunk
def __len__(self): def __len__(self):
return len(self.chunks) return len(self.chunks)
...@@ -89,22 +92,21 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs): ...@@ -89,22 +92,21 @@ def load_or_create(checkpoint, dataset_path, chunk_size, *args, **kwargs):
return training, validation, testing return training, validation, testing
def _collate(batch): def collate(batch):
input_keys = ['coordinates', 'species'] no_collate = ['coordinates', 'species']
inputs = [{k: i[k] for k in input_keys} for i in batch] if isinstance(batch[0], torch.Tensor):
outputs = {} return torch.cat(batch)
for i in batch: elif isinstance(batch[0], collections.Mapping):
for j in i: return {key: ((lambda x: x) if key in no_collate else collate)
if j in input_keys: ([d[key] for d in batch])
continue for key in batch[0]}
if j not in outputs: elif isinstance(batch[0], collections.Sequence):
outputs[j] = [] transposed = zip(*batch)
outputs[j].append(i[j]) return [collate(samples) for samples in transposed]
for i in outputs: else:
outputs[i] = torch.cat(outputs[i]) raise ValueError('Unexpected element type: {}'.format(type(batch[0])))
return inputs, outputs
def dataloader(dataset, batch_chunks, shuffle=True, **kwargs): def dataloader(dataset, batch_chunks, shuffle=True, **kwargs):
return DataLoader(dataset, batch_chunks, shuffle, return DataLoader(dataset, batch_chunks, shuffle,
collate_fn=_collate, **kwargs) collate_fn=collate, **kwargs)
from .container import Container from .container import Container
from .loss_metrics import DictLoss, DictMetric, energy_mse_loss, \ from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric
energy_rmse_metric
__all__ = ['Container', 'DictLoss', 'DictMetric', 'energy_mse_loss', __all__ = ['Container', 'DictLoss', 'DictMetric', 'MSELoss', 'RMSEMetric']
'energy_rmse_metric']
import torch import torch
from ..models import BatchModel from ..data import collate
class Container(torch.nn.Module): class Container(torch.nn.Module):
...@@ -8,13 +8,15 @@ class Container(torch.nn.Module): ...@@ -8,13 +8,15 @@ class Container(torch.nn.Module):
super(Container, self).__init__() super(Container, self).__init__()
self.keys = models.keys() self.keys = models.keys()
for i in models: for i in models:
if not isinstance(models[i], BatchModel):
raise ValueError('Container must contain batch models')
setattr(self, 'model_' + i, models[i]) setattr(self, 'model_' + i, models[i])
def forward(self, batch): def forward(self, batch):
output = {} all_results = []
for i in self.keys: for i in zip(batch['species'], batch['coordinates']):
model = getattr(self, 'model_' + i) results = {}
output[i] = model(batch) for k in self.keys:
return output model = getattr(self, 'model_' + k)
_, results[k] = model(i)
all_results.append(results)
batch.update(collate(all_results))
return batch
...@@ -4,6 +4,14 @@ from ignite.metrics import RootMeanSquaredError ...@@ -4,6 +4,14 @@ from ignite.metrics import RootMeanSquaredError
import torch import torch
def num_atoms(input):
ret = []
for s, c in zip(input['species'], input['coordinates']):
ret.append(torch.full((c.shape[0],), len(s),
dtype=c.dtype, device=c.device))
return torch.cat(ret)
class DictLoss(_Loss): class DictLoss(_Loss):
def __init__(self, key, loss): def __init__(self, key, loss):
...@@ -33,5 +41,9 @@ class DictMetric(Metric): ...@@ -33,5 +41,9 @@ class DictMetric(Metric):
return self.metric.compute() return self.metric.compute()
energy_mse_loss = DictLoss('energies', torch.nn.MSELoss()) def MSELoss(key):
energy_rmse_metric = DictMetric('energies', RootMeanSquaredError()) return DictLoss(key, torch.nn.MSELoss())
def RMSEMetric(key):
return DictMetric(key, RootMeanSquaredError())
from .custom import CustomModel from .custom import CustomModel
from .neurochem_nnp import NeuroChemNNP from .neurochem_nnp import NeuroChemNNP
from .batch import BatchModel
__all__ = ['CustomModel', 'NeuroChemNNP', 'BatchModel'] __all__ = ['CustomModel', 'NeuroChemNNP']
...@@ -85,4 +85,4 @@ class ANIModel(BenchmarkedModule): ...@@ -85,4 +85,4 @@ class ANIModel(BenchmarkedModule):
per_species_outputs = torch.cat(per_species_outputs, dim=1) per_species_outputs = torch.cat(per_species_outputs, dim=1)
molecule_output = self.reducer(per_species_outputs, dim=1) molecule_output = self.reducer(per_species_outputs, dim=1)
return molecule_output return species, molecule_output
import torch


class BatchModel(torch.nn.Module):
    """Run a single-chunk model over a whole batch of chunks.

    The wrapped ``model`` maps one ``(species, coordinates)`` pair to a
    tensor of per-molecule outputs; :meth:`forward` applies it to each
    chunk in the batch and concatenates the results along dim 0.
    """

    def __init__(self, model):
        """Store the per-chunk ``model`` to be applied batch-wise."""
        super(BatchModel, self).__init__()
        self.model = model

    def forward(self, batch):
        """Apply the wrapped model to every chunk and concatenate.

        ``batch`` is an iterable of dicts, each providing the 'species'
        and 'coordinates' of one chunk. Returns a single tensor with all
        per-chunk outputs stacked along dim 0.
        """
        results = []
        for chunk in batch:
            coordinates = chunk['coordinates']
            species = chunk['species']
            results.append(self.model((species, coordinates)))
        return torch.cat(results)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment