Commit b9e2c259 authored by Richard Xue, committed by Gao, Xiang

New dataset API, cached dataset and shuffled dataset (#284)

parent f825c99e
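For orientation, here is a minimal sketch of the new dataset API this commit adds, using only the constructor arguments and batch layout that appear in the tests and benchmark further down; the dataset path is a placeholder, and nothing beyond what those files show is assumed.

import torchani

dspath = 'path/to/ani_dataset.h5'  # placeholder path

# Shuffled, streaming batches loaded by background workers.
shuffled = torchani.data.ShuffledDataset(
    file_path=dspath,
    species_order=['H', 'C', 'N', 'O'],
    subtract_self_energies=True,
    batch_size=2560,
    chunk_threshold=5,
    num_workers=2)

# All batches preprocessed once and held in memory, indexable like a list.
cached = torchani.data.CachedDataset(
    file_path=dspath,
    species_order=['H', 'C', 'N', 'O'],
    subtract_self_energies=True,
    batch_size=2560)

# Each batch is ([chunk1, chunk2, ...], {'energies': ..., ...}),
# where each chunk is a (species, coordinates) tensor pair.
chunks, properties = next(iter(shuffled))
chunks, properties = cached[0]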
......@@ -2,5 +2,5 @@
python -m pip install --upgrade pip
pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install tqdm pyyaml future
pip install tqdm pyyaml future pkbar
pip install 'ase<=3.17'
\ No newline at end of file
......@@ -2,5 +2,5 @@
python -m pip install --upgrade pip
pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
pip install tqdm pyyaml future
pip install tqdm pyyaml future pkbar
pip2 install 'ase<=3.17'
\ No newline at end of file
......@@ -24,6 +24,9 @@ Datasets
========
.. automodule:: torchani.data
.. autofunction:: torchani.data.find_threshold
.. autofunction:: torchani.data.ShuffledDataset
.. autoclass:: torchani.data.CachedDataset
.. autofunction:: torchani.data.load_ani_dataset
.. autofunction:: torchani.data.create_aev_cache
.. autoclass:: torchani.data.BatchedANIDataset
......
......@@ -5,7 +5,7 @@ with-coverage=1
cover-package=torchani
[flake8]
ignore = E501
ignore = E501, W503
exclude =
.git,
__pycache__,
......
......@@ -31,6 +31,7 @@ setup_attrs = {
'h5py',
'pytorch-ignite-nightly',
'pillow',
'pkbar'
],
}
......
import torchani
import unittest
import pkbar
import torch
import os
path = os.path.dirname(os.path.realpath(__file__))
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
batch_size = 2560
chunk_threshold = 5
class TestFindThreshold(unittest.TestCase):
def setUp(self):
print('.. check find_threshold to split chunks')
def testFindThreshold(self):
torchani.data.find_threshold(dspath, batch_size=batch_size, threshold_max=10)
class TestShuffledData(unittest.TestCase):
def setUp(self):
print('.. setup shuffle dataset')
self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2)
self.chunks, self.properties = next(iter(self.ds))
def testTensorShape(self):
print('=> checking tensor shape')
print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
batch_len = 0
for i, chunk in enumerate(self.chunks):
print('chunk{}'.format(i + 1), list(chunk[0].size()), chunk[0].dtype, list(chunk[1].size()), chunk[1].dtype)
# check dtype
self.assertEqual(chunk[0].dtype, torch.int64)
self.assertEqual(chunk[1].dtype, torch.float32)
# check shape
self.assertEqual(chunk[1].shape[2], 3)
self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
batch_len += chunk[0].shape[0]
for key, value in self.properties.items():
print(key, list(value.size()), value.dtype)
self.assertEqual(value.dtype, torch.float32)
self.assertEqual(len(value.shape), 1)
self.assertEqual(value.shape[0], batch_len)
def testLoadDataset(self):
print('=> test loading the whole dataset')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
len(self.ds))
for i, _ in enumerate(self.ds):
pbar.update(i)
def testNoUnnecessaryPadding(self):
print('=> checking No Unnecessary Padding')
for i, chunk in enumerate(self.chunks):
species, _ = chunk
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
class TestCachedData(unittest.TestCase):
def setUp(self):
print('.. setup cached dataset')
self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu', chunk_threshold=chunk_threshold)
self.chunks, self.properties = self.ds[0]
def testTensorShape(self):
print('=> checking tensor shape')
print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
batch_len = 0
for i, chunk in enumerate(self.chunks):
print('chunk{}'.format(i + 1), list(chunk[0].size()), chunk[0].dtype, list(chunk[1].size()), chunk[1].dtype)
# check dtype
self.assertEqual(chunk[0].dtype, torch.int64)
self.assertEqual(chunk[1].dtype, torch.float32)
# check shape
self.assertEqual(chunk[1].shape[2], 3)
self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
batch_len += chunk[0].shape[0]
for key, value in self.properties.items():
print(key, list(value.size()), value.dtype)
self.assertEqual(value.dtype, torch.float32)
self.assertEqual(len(value.shape), 1)
self.assertEqual(value.shape[0], batch_len)
def testLoadDataset(self):
print('=> test loading the whole dataset')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
len(self.ds))
for i, _ in enumerate(self.ds):
pbar.update(i)
def testNoUnnecessaryPadding(self):
print('=> checking No Unnecessary Padding')
for i, chunk in enumerate(self.chunks):
species, _ = chunk
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
if __name__ == "__main__":
unittest.main()
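The chunk_threshold used by both test fixtures above controls how each batch is split into padded chunks, and the first test exercises torchani.data.find_threshold, the helper meant to guide that choice. A minimal sketch of the call, assuming only the signature used in the test (the path is a placeholder):

import torchani

dspath = 'path/to/ani_dataset.h5'  # placeholder path

# Try candidate split thresholds up to threshold_max for the given batch
# size; the output is used to pick a chunk_threshold to pass to
# ShuffledDataset or CachedDataset (behavior inferred from the test above).
torchani.data.find_threshold(dspath, batch_size=2560, threshold_max=10)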
import torch
import ignite
import torchani
import time
import timeit
import tqdm
import argparse
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset, can be an HDF5 file \
or a directory containing HDF5 files')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
help='Number of conformations of each batch',
default=1024, type=int)
parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
ani1x = torchani.models.ANI1x()
consts = ani1x.consts
aev_computer = ani1x.aev_computer
shift_energy = ani1x.energy_shifter
import pkbar
def atomic():
......@@ -39,45 +19,6 @@ def atomic():
return model
model = torchani.ANIModel([atomic() for _ in range(4)])
class Flatten(torch.nn.Module):
def forward(self, x):
return x[0], x[1].flatten()
nnp = torch.nn.Sequential(aev_computer, model, Flatten()).to(device)
dataset = torchani.data.load_ani_dataset(
parser.dataset_path, consts.species_to_tensor,
parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.MSELoss('energies'))
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')
@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
trainer.state.tqdm.update(1)
@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
trainer.state.tqdm.close()
timers = {}
def time_func(key, func):
timers[key] = 0
......@@ -91,27 +32,153 @@ def time_func(key, func):
return wrapper
# enable timers
torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)
# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
for k in timers:
def hartree2kcal(x):
return 627.509 * x
if __name__ == "__main__":
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset, can be an HDF5 file \
or a directory containing HDF5 files')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('-b', '--batch_size',
help='Number of conformations of each batch',
default=2560, type=int)
parser.add_argument('-o', '--original_dataset_api',
help='use the original dataset API',
dest='dataset',
action='store_const',
const='original')
parser.add_argument('-s', '--shuffle_dataset_api',
help='use the shuffled dataset API',
dest='dataset',
action='store_const',
const='shuffle')
parser.add_argument('-c', '--cache_dataset_api',
help='use the cached dataset API',
dest='dataset',
action='store_const',
const='cache')
parser.set_defaults(dataset='shuffle')
parser.add_argument('-n', '--num_epochs',
help='number of epochs',
default=1, type=int)
parser = parser.parse_args()
Rcr = 5.2000e+00
Rca = 3.5000e+00
EtaR = torch.tensor([1.6000000e+01], device=parser.device)
ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=parser.device)
Zeta = torch.tensor([3.2000000e+01], device=parser.device)
ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=parser.device)
EtaA = torch.tensor([8.0000000e+00], device=parser.device)
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=parser.device)
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
nn = torchani.ANIModel([atomic() for _ in range(4)])
model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
mse = torch.nn.MSELoss(reduction='none')
timers = {}
# enable timers
torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
torchani.aev.convert_pair_index = time_func('torchani.aev.convert_pair_index', torchani.aev.convert_pair_index)
torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
model[0].forward = time_func('total', model[0].forward)
model[1].forward = time_func('forward', model[1].forward)
if parser.dataset == 'shuffle':
torchani.data.ShuffledDataset = time_func('data_loading', torchani.data.ShuffledDataset)
print('using shuffled dataset API')
print('=> loading dataset...')
dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size,
num_workers=2)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = next(iter(dataset))
elif parser.dataset == 'original':
torchani.data.load_ani_dataset = time_func('data_loading', torchani.data.load_ani_dataset)
print('using original dataset API')
print('=> loading dataset...')
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
parser.batch_size, device=parser.device,
transform=[energy_shifter.subtract_from_dataset])
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
elif parser.dataset == 'cache':
torchani.data.CachedDataset = time_func('data_loading', torchani.data.CachedDataset)
print('using cached dataset API')
print('=> loading dataset...')
dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size)
print('=> caching the whole dataset into CPU memory')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
len(dataset))
for i, t in enumerate(dataset):
pbar.update(i)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
for i, chunk in enumerate(chunks):
print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
print('energies', list(properties['energies'].size()))
print('=> start training')
start = time.time()
for epoch in range(0, parser.num_epochs):
print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
for i, (batch_x, batch_y) in enumerate(dataset):
true_energies = batch_y['energies'].to(parser.device)
predicted_energies = []
num_atoms = []
for chunk_species, chunk_coordinates in batch_x:
chunk_species = chunk_species.to(parser.device)
chunk_coordinates = chunk_coordinates.to(parser.device)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies)
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
rmse = hartree2kcal(mse(predicted_energies, true_energies).mean().sqrt()).detach().cpu().numpy()
optimizer.zero_grad()
loss.backward()
optimizer.step()
progbar.update(i, values=[("rmse", rmse)])
stop = time.time()
print('=> more details about the benchmark')
for k in timers:
if k.startswith('torchani.'):
print(k, timers[k])
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)
print('{} - {:.1f}s'.format(k, timers[k]))
print('Total AEV - {:.1f}s'.format(timers['total']))
print('Data Loading - {:.1f}s'.format(timers['data_loading']))
print('NN - {:.1f}s'.format(timers['forward']))
print('Epoch time - {:.1f}s'.format(stop - start))
......@@ -11,6 +11,7 @@ import pickle
import numpy as np
from scipy.sparse import bsr_matrix
import warnings
from .new import CachedDataset, ShuffledDataset, find_threshold
default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
......@@ -511,4 +512,6 @@ def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
SparseAEVCacheLoader.encode_aev, **kwargs)
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader', 'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev']
__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader',
'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev',
'CachedDataset', 'ShuffledDataset', 'find_threshold']
This diff is collapsed.