Commit 3cced1e6 (unverified), authored by Gao, Xiang; committed by GitHub

Improve training related API (#76)

parent d7ef8182
@@ -24,9 +24,9 @@ steps:
   Examples:
     image: '${{BuildTorchANI}}'
     commands:
-      - rm -rf *.dat *.pt
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5
-      - python examples/nnp_training.py ./dataset/ani_gdb_s01.h5 # run twice to test if checkpoint is working
+      - rm -rf *.pt
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5
+      - python examples/nnp_training.py dataset/ani_gdb_s01.h5 dataset/ani_gdb_s01.h5 # run twice to test if checkpoint is working
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5
       - python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5 # run twice to test if checkpoint is working
       - python examples/energy_force.py
@@ -22,7 +22,7 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
                              [0.45554739, 0.54289633, 0.81170881],
                              [0.66091919, -0.16799635, -0.91037834]]],
                             requires_grad=True)
-species = consts.species_to_tensor('CHHHH', device).unsqueeze(0)
+species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)
 _, energy = model((species, coordinates))
 derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
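
The device argument is gone from species_to_tensor; device placement now composes with the usual tensor .to() chaining. A minimal sketch of the new calling pattern, assuming the built-in constants object torchani.buildins.consts used elsewhere in this commit ('CHHHH' is the methane ordering from the example above):

import torch
import torchani

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
consts = torchani.buildins.consts

# species_to_tensor now returns a CPU long tensor; move it explicitly
species = consts.species_to_tensor('CHHHH').to(device).unsqueeze(0)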
@@ -35,15 +35,15 @@ shift_energy = torchani.neurochem.load_sae(parser.sae_file)
 aev_computer = torchani.AEVComputer(**consts)
 nn = torchani.neurochem.load_model(consts.species, parser.network_dir)
 model = torch.nn.Sequential(aev_computer, nn)
-container = torchani.training.Container({'energies': model})
+container = torchani.ignite.Container({'energies': model})
 container = container.to(device)

 # load datasets
 if parser.dataset_path.endswith('.h5') or \
         parser.dataset_path.endswith('.hdf5') or \
         os.path.isdir(parser.dataset_path):
-    dataset = torchani.training.BatchedANIDataset(
-        parser.dataset_path, consts.species, parser.batch_size,
+    dataset = torchani.data.BatchedANIDataset(
+        parser.dataset_path, consts.species_to_tensor, parser.batch_size,
         device=device, transform=[shift_energy.subtract_from_dataset])
     datasets = [dataset]
 else:
@@ -60,7 +60,7 @@ def hartree2kcal(x):
 for dataset in datasets:
     evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-        'RMSE': torchani.training.RMSEMetric('energies')
+        'RMSE': torchani.ignite.RMSEMetric('energies')
     })
     evaluator.run(dataset)
     metrics = evaluator.state.metrics
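
The hartree2kcal helper named in the hunk header is not itself shown in this diff; for reference, a plausible one-liner using the standard conversion factor (1 Hartree ≈ 627.509 kcal/mol):

def hartree2kcal(x):
    # 1 Hartree is approximately 627.509 kcal/mol
    return 627.509 * x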
@@ -11,18 +11,18 @@ import json
 # parse command line arguments
 parser = argparse.ArgumentParser()
-parser.add_argument('dataset_path',
-                    help='Path of the dataset, can a hdf5 file \
+parser.add_argument('training_path',
+                    help='Path of the training set, can be a hdf5 file \
                           or a directory containing hdf5 files')
+parser.add_argument('validation_path',
+                    help='Path of the validation set, can be a hdf5 file \
+                          or a directory containing hdf5 files')
-parser.add_argument('--dataset_checkpoint',
-                    help='Checkpoint file for datasets',
-                    default='dataset-checkpoint.dat')
 parser.add_argument('--model_checkpoint',
                     help='Checkpoint file for model',
                     default='model.pt')
 parser.add_argument('-m', '--max_epochs',
                     help='Maximum number of epoches',
-                    default=100, type=int)
+                    default=300, type=int)
 parser.add_argument('--training_rmse_every',
                     help='Compute training RMSE every epoches',
                     default=20, type=int)
@@ -53,20 +53,24 @@ start = timeit.default_timer()
 nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
 shift_energy = torchani.buildins.energy_shifter
-training, validation, testing = torchani.training.load_or_create(
-    parser.dataset_checkpoint, parser.batch_size, model.consts.species,
-    parser.dataset_path, device=device,
+training = torchani.data.BatchedANIDataset(
+    parser.training_path, model.consts.species_to_tensor,
+    parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
+validation = torchani.data.BatchedANIDataset(
+    parser.validation_path, model.consts.species_to_tensor,
+    parser.batch_size, device=device,
+    transform=[shift_energy.subtract_from_dataset])

-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 parser.optim_args = json.loads(parser.optim_args)
 optimizer = getattr(torch.optim, parser.optimizer)
 optimizer = optimizer(nnp.parameters(), **parser.optim_args)
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))
 evaluator = ignite.engine.create_supervised_evaluator(container, metrics={
-    'RMSE': torchani.training.RMSEMetric('energies')
+    'RMSE': torchani.ignite.RMSEMetric('energies')
 })
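
The optimizer here is built dynamically: the class is looked up by name on torch.optim, and its keyword arguments are parsed from a JSON string. A self-contained sketch of that pattern (the model and the two string values are hypothetical stand-ins for the parsed --optimizer and --optim_args options):

import json
import torch

model = torch.nn.Linear(4, 1)            # stand-in for the real network
optimizer_name = 'Adam'                  # would come from --optimizer
optim_args = json.loads('{"lr": 1e-3}')  # would come from --optim_args

# look up the optimizer class by name and pass the parsed kwargs through
optimizer_cls = getattr(torch.optim, optimizer_name)
optimizer = optimizer_cls(model.parameters(), **optim_args)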
@@ -97,19 +101,25 @@ def finalize_tqdm(trainer):
 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
 def validation_and_checkpoint(trainer):
+    def evaluate(dataset, name):
+        evaluator = ignite.engine.create_supervised_evaluator(
+            container,
+            metrics={
+                'RMSE': torchani.ignite.RMSEMetric('energies')
+            }
+        )
+        evaluator.run(dataset)
+        metrics = evaluator.state.metrics
+        rmse = hartree2kcal(metrics['RMSE'])
+        writer.add_scalar(name, rmse, trainer.state.epoch)
+        return rmse

     # compute validation RMSE
-    evaluator.run(validation)
-    metrics = evaluator.state.metrics
-    rmse = hartree2kcal(metrics['RMSE'])
-    writer.add_scalar('validation_rmse_vs_epoch', rmse, trainer.state.epoch)
+    rmse = evaluate(validation, 'validation_rmse_vs_epoch')

     # compute training RMSE
-    if trainer.state.epoch % parser.training_rmse_every == 0:
-        evaluator.run(training)
-        metrics = evaluator.state.metrics
-        rmse = hartree2kcal(metrics['RMSE'])
-        writer.add_scalar('training_rmse_vs_epoch', rmse,
-                          trainer.state.epoch)
+    if trainer.state.epoch % parser.training_rmse_every == 1:
+        evaluate(training, 'training_rmse_vs_epoch')

     # handle best validation RMSE
     if rmse < trainer.state.best_validation_rmse:
@@ -120,9 +130,12 @@ def validation_and_checkpoint(trainer):
         torch.save(nnp.state_dict(), parser.model_checkpoint)
     else:
         trainer.state.no_improve_count += 1
+    writer.add_scalar('no_improve_count_vs_epoch',
+                      trainer.state.no_improve_count,
+                      trainer.state.epoch)
     if trainer.state.no_improve_count > parser.early_stopping:
         trainer.terminate()

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
@@ -134,8 +147,7 @@ def log_time(trainer):
 @trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
 def log_loss_and_time(trainer):
     iteration = trainer.state.iteration
-    rmse = hartree2kcal(math.sqrt(trainer.state.output))
-    writer.add_scalar('training_atomic_rmse_vs_iteration', rmse, iteration)
+    writer.add_scalar('loss_vs_iteration', trainer.state.output, iteration)

 trainer.run(training, max_epochs=parser.max_epochs)
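
The writer used throughout these handlers is never defined in the visible hunks; a minimal sketch of one way to set it up, assuming the tensorboardX SummaryWriter that the add_scalar calls suggest:

from tensorboardX import SummaryWriter

# scalars logged via writer.add_scalar(tag, value, step) can then be
# inspected with: tensorboard --logdir runs
writer = SummaryWriter()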
@@ -23,15 +23,15 @@ parser = parser.parse_args()
 device = torch.device(parser.device)
 nnp = model.get_or_create_model('/tmp/model.pt', device=device)
 shift_energy = torchani.buildins.energy_shifter
-dataset = torchani.training.BatchedANIDataset(
-    parser.dataset_path, model.consts.species,
+dataset = torchani.data.BatchedANIDataset(
+    parser.dataset_path, model.consts.species_to_tensor,
     parser.batch_size, device=device,
     transform=[shift_energy.subtract_from_dataset])
-container = torchani.training.Container({'energies': nnp})
+container = torchani.ignite.Container({'energies': nnp})
 optimizer = torch.optim.Adam(nnp.parameters())
 trainer = ignite.engine.create_supervised_trainer(
-    container, optimizer, torchani.training.MSELoss('energies'))
+    container, optimizer, torchani.ignite.MSELoss('energies'))

 @trainer.on(ignite.engine.Events.EPOCH_STARTED)
@@ -12,9 +12,9 @@ consts = torchani.buildins.consts
 class TestData(unittest.TestCase):

     def setUp(self):
-        self.ds = torchani.training.BatchedANIDataset(dataset_path,
-                                                      consts.species,
-                                                      batch_size)
+        self.ds = torchani.data.BatchedANIDataset(dataset_path,
+                                                  consts.species_to_tensor,
+                                                  batch_size)

     def _assertTensorEqual(self, t1, t2):
         self.assertEqual((t1-t2).abs().max(), 0)
@@ -32,8 +32,7 @@ class TestData(unittest.TestCase):
             (species3, coordinates3),
         ])
         natoms = (species >= 0).to(torch.long).sum(1)
-        chunks = torchani.training.data.split_batch(natoms, species,
-                                                    coordinates)
+        chunks = torchani.data.split_batch(natoms, species, coordinates)
         start = 0
         last = None
         for s, c in chunks:
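
split_batch, now exported at torchani.data, takes a padded batch plus its per-molecule atom counts (padding atoms carry species index -1, per the natoms computation above) and yields (species, coordinates) chunks. A sketch of calling it, under the assumption that the batch is already sorted by atom count, as BatchedANIDataset does before splitting:

import torch
import torchani

# two padded molecules: -1 marks padding atoms
species = torch.tensor([[0, 1, -1], [0, 1, 1]])
coordinates = torch.zeros(2, 3, 3)

natoms = (species >= 0).to(torch.long).sum(1)
chunks = torchani.data.split_batch(natoms, species, coordinates)
for s, c in chunks:
    print(s.shape, c.shape)  # chunk shapes after padding is stripped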
@@ -5,7 +5,7 @@ import copy
 from ignite.engine import create_supervised_trainer, \
     create_supervised_evaluator, Events
 import torchani
-import torchani.training
+import torchani.ignite

 path = os.path.dirname(os.path.realpath(__file__))
 path = os.path.join(path, '../dataset/ani_gdb_s01.h5')
@@ -19,8 +19,8 @@ class TestIgnite(unittest.TestCase):
         aev_computer = torchani.buildins.aev_computer
         nnp = copy.deepcopy(torchani.buildins.models[0])
         shift_energy = torchani.buildins.energy_shifter
-        ds = torchani.training.BatchedANIDataset(
-            path, torchani.buildins.consts.species, batchsize,
+        ds = torchani.data.BatchedANIDataset(
+            path, torchani.buildins.consts.species_to_tensor, batchsize,
             transform=[shift_energy.subtract_from_dataset])
         ds = torch.utils.data.Subset(ds, [0])
@@ -29,15 +29,15 @@ class TestIgnite(unittest.TestCase):
                 return x[0], x[1].flatten()

         model = torch.nn.Sequential(aev_computer, nnp, Flatten())
-        container = torchani.training.Container({'energies': model})
+        container = torchani.ignite.Container({'energies': model})
         optimizer = torch.optim.Adam(container.parameters())
-        loss = torchani.training.TransformedLoss(
-            torchani.training.MSELoss('energies'),
+        loss = torchani.ignite.TransformedLoss(
+            torchani.ignite.MSELoss('energies'),
             lambda x: torch.exp(x) - 1)
         trainer = create_supervised_trainer(
             container, optimizer, loss)
         evaluator = create_supervised_evaluator(container, metrics={
-            'RMSE': torchani.training.RMSEMetric('energies')
+            'RMSE': torchani.ignite.RMSEMetric('energies')
         })

         @trainer.on(Events.COMPLETED)
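
The transform in this test is worth a note: since exp(x) - 1 ≈ x for small x, the wrapped loss is nearly unchanged when the base MSE is small but grows much faster when it is large, penalizing outliers harder. A sketch of the same composition outside the test, assuming only the call signature shown above:

import torch
import torchani.ignite

loss = torchani.ignite.TransformedLoss(
    torchani.ignite.MSELoss('energies'),  # base loss on the 'energies' key
    lambda x: torch.exp(x) - 1)           # ~x near 0, steep for large losses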
 from .utils import EnergyShifter
 from .models import ANIModel, Ensemble
 from .aev import AEVComputer
-from . import training
+from . import ignite
 from . import utils
 from . import neurochem
+from . import data
 from .neurochem import buildins

 __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
-           'training', 'utils', 'neurochem']
+           'ignite', 'utils', 'neurochem', 'data']
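
For downstream code, the reshuffle amounts to everything under the old torchani.training being split between two modules. A summary sketch (module and symbol names taken from the diffs in this commit):

import torchani.ignite  # Container, MSELoss, RMSEMetric, TransformedLoss
import torchani.data    # BatchedANIDataset, split_batch

# torchani.training.load_or_create is removed outright; training and
# validation sets are now loaded from separate paths (see nnp_training.py).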
 from torch.utils.data import Dataset
 from os.path import join, isfile, isdir
 import os
-from .pyanitools import anidataloader
+from ._pyanitools import anidataloader
 import torch
-import torch.utils.data as data
-import pickle
-from .. import utils
+from . import utils


 def chunk_counts(counts, split):
@@ -77,14 +75,11 @@ def split_batch(natoms, species, coordinates):
 class BatchedANIDataset(Dataset):

-    def __init__(self, path, species, batch_size, shuffle=True,
-                 properties=['energies'], transform=(),
+    def __init__(self, path, species_tensor_converter, batch_size,
+                 shuffle=True, properties=['energies'], transform=(),
                  dtype=torch.get_default_dtype(), device=torch.device('cpu')):
         super(BatchedANIDataset, self).__init__()
         self.path = path
-        self.species = species
-        self.species_indices = {
-            self.species[i]: i for i in range(len(self.species))}
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.properties = properties
@@ -108,13 +103,12 @@ class BatchedANIDataset(Dataset):
         properties = {k: [] for k in self.properties}
         for f in files:
             for m in anidataloader(f):
-                species = m['species']
-                indices = [self.species_indices[i] for i in species]
-                species = torch.tensor(indices, dtype=torch.long)
-                coordinates = torch.from_numpy(m['coordinates'])
-                species_coordinates.append((species, coordinates))
+                s = species_tensor_converter(m['species'])
+                c = torch.from_numpy(m['coordinates']).to(torch.double)
+                species_coordinates.append((s, c))
                 for i in properties:
-                    properties[i].append(torch.from_numpy(m[i]))
+                    p = torch.from_numpy(m[i]).to(torch.double)
+                    properties[i].append(p)
         species, coordinates = utils.pad_and_batch(species_coordinates)
         for i in properties:
             properties[i] = torch.cat(properties[i])
@@ -136,9 +130,10 @@ class BatchedANIDataset(Dataset):
         # convert to desired dtype
         species = species
         coordinates = coordinates.to(dtype)
-        properties = {k: properties[k].to(dtype) for k in properties}
+        for k in properties:
+            properties[k] = properties[k].to(dtype)

-        # split into minibatches, and strip reduncant padding
+        # split into minibatches, and strip redundant padding
         natoms = (species >= 0).to(torch.long).sum(1)
         batches = []
         num_batches = (conformations + batch_size - 1) // batch_size
@@ -146,6 +141,7 @@
             start = i * batch_size
             end = min((i + 1) * batch_size, conformations)
             natoms_batch = natoms[start:end]
+            # sort batch by number of atoms to prepare for splitting
             natoms_batch, indices = natoms_batch.sort()
             species_batch = species[start:end, ...].index_select(0, indices)
             coordinates_batch = coordinates[start:end, ...] \
@@ -172,24 +168,3 @@ class BatchedANIDataset(Dataset):
     def __len__(self):
         return len(self.batches)
-
-
-def load_or_create(checkpoint, batch_size, species, dataset_path,
-                   *args, **kwargs):
-    """Generate a 80-10-10 split of the dataset, and checkpoint
-    the resulting dataset"""
-    if not os.path.isfile(checkpoint):
-        full_dataset = BatchedANIDataset(dataset_path, species, batch_size,
-                                         *args, **kwargs)
-        training_size = int(len(full_dataset) * 0.8)
-        validation_size = int(len(full_dataset) * 0.1)
-        testing_size = len(full_dataset) - training_size - validation_size
-        lengths = [training_size, validation_size, testing_size]
-        subsets = data.random_split(full_dataset, lengths)
-        with open(checkpoint, 'wb') as f:
-            pickle.dump(subsets, f)
-    # load dataset from checkpoint file
-    with open(checkpoint, 'rb') as f:
-        training, validation, testing = pickle.load(f)
-    return training, validation, testing
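
With load_or_create gone, callers construct one BatchedANIDataset per split directly. A minimal sketch using the built-in constants and energy shifter seen elsewhere in this commit (the file path and batch size are just examples; any ANI-format hdf5 file or directory of them works):

import torchani

consts = torchani.buildins.consts
shift_energy = torchani.buildins.energy_shifter

dataset = torchani.data.BatchedANIDataset(
    'dataset/ani_gdb_s01.h5', consts.species_to_tensor, 256,
    transform=[shift_energy.subtract_from_dataset])

inputs, targets = dataset[0]  # one minibatch, the (x, y) pair ignite consumes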
+import torch
+from . import utils
 from torch.nn.modules.loss import _Loss
 from ignite.metrics.metric import Metric
 from ignite.metrics import RootMeanSquaredError
-import torch


+class Container(torch.nn.ModuleDict):
+
+    def __init__(self, modules):
+        super(Container, self).__init__(modules)
+
+    def forward(self, species_coordinates):
+        results = {k: [] for k in self}
+        for sc in species_coordinates:
+            for k in self:
+                _, result = self[k](sc)
+                results[k].append(result)
+        for k in self:
+            results[k] = torch.cat(results[k])
+        results['species'], results['coordinates'] = \
+            utils.pad_and_batch(species_coordinates)
+        return results


 class DictLoss(_Loss):
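
The rewritten Container subclasses torch.nn.ModuleDict, so the wrapped models are registered as submodules automatically; the old version (deleted container.py, below) did this by hand with setattr and a separate keys attribute. A small sketch of that registration behavior with a hypothetical stand-in model:

import torch

class ToyModel(torch.nn.Module):
    """Stand-in with the (species, coordinates) -> (species, value) convention."""
    def __init__(self):
        super(ToyModel, self).__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, species_coordinates):
        species, coordinates = species_coordinates
        return species, self.scale * coordinates.sum(dim=(1, 2))

container = torch.nn.ModuleDict({'energies': ToyModel()})
print(list(container))                    # ['energies'] -- iterating yields keys
print(len(list(container.parameters())))  # 1 -- submodule params are registered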
@@ -56,9 +56,9 @@ class Constants(Mapping):
     def __getitem__(self, item):
         return getattr(self, item)

-    def species_to_tensor(self, species, device):
+    def species_to_tensor(self, species):
         rev = [self.rev_species[s] for s in species]
-        return torch.tensor(rev, dtype=torch.long, device=device)
+        return torch.tensor(rev, dtype=torch.long)


 def load_sae(filename):
-from .container import Container
-from .data import BatchedANIDataset, load_or_create
-from .loss_metrics import DictLoss, DictMetric, MSELoss, RMSEMetric, \
-    TransformedLoss
-from . import pyanitools
-
-__all__ = ['Container', 'BatchedANIDataset', 'load_or_create', 'DictLoss',
-           'DictMetric', 'MSELoss', 'RMSEMetric', 'TransformedLoss',
-           'pyanitools']
-import torch
-from .. import utils
-
-
-class Container(torch.nn.Module):
-
-    def __init__(self, models):
-        super(Container, self).__init__()
-        self.keys = models.keys()
-        for i in models:
-            setattr(self, 'model_' + i, models[i])
-
-    def forward(self, species_coordinates):
-        results = {k: [] for k in self.keys}
-        for sc in species_coordinates:
-            for k in self.keys:
-                model = getattr(self, 'model_' + k)
-                _, result = model(sc)
-                results[k].append(result)
-        results['species'], results['coordinates'] = \
-            utils.pad_and_batch(species_coordinates)
-        for k in self.keys:
-            results[k] = torch.cat(results[k])
-        return results