Commit 327a9b20 authored by Gao, Xiang, committed by GitHub

rewrite dataloader API to improve performance and reduce code complexity (#18)

parent f6ef4ebb
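A minimal usage sketch of the rewritten API, inferred from the new ANIDataset constructor and tests in this diff; the 'dataset' directory and the chunk size of 64 are placeholder values, not part of the commit:

import torchani.data

# load a directory of .h5/.hdf5 files, pre-split into chunks of at most 64 conformations
ds = torchani.data.ANIDataset('dataset', 64)
for xyz, energies, species in ds:
    # each item is one chunk: conformations of a single molecule
    assert xyz.shape[0] <= 64 and xyz.shape[0] == energies.shape[0]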
import sys
if sys.version_info.major >= 3:
import torchani
import unittest
import tempfile
import os
import torch
import torchani.pyanitools as pyanitools
import unittest
import torchani.data
from math import ceil
from bisect import bisect
from pickle import dump, load
path = os.path.dirname(os.path.realpath(__file__))
dataset_dir = os.path.join(path, 'dataset')
path = os.path.join(path, 'dataset')
class TestDataset(unittest.TestCase):
def setUp(self, data_path=dataset_dir):
self.data_path = data_path
self.ds = torchani.data.load_dataset(data_path)
def testLen(self):
# compute data length using Dataset
l1 = len(self.ds)
# compute data length using pyanitools
l2 = 0
for f in os.listdir(self.data_path):
f = os.path.join(self.data_path, f)
if os.path.isfile(f) and \
(f.endswith('.h5') or f.endswith('.hdf5')):
for j in pyanitools.anidataloader(f):
l2 += j['energies'].shape[0]
# compute data length using iterator
l3 = len(list(self.ds))
# these lengths should match
self.assertEqual(l1, l2)
self.assertEqual(l1, l3)
def testNumChunks(self):
chunksize = 64
# compute number of chunks using batch sampler
bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
l1 = len(bs)
# compute number of chunks using pyanitools
l2 = 0
for f in os.listdir(self.data_path):
f = os.path.join(self.data_path, f)
if os.path.isfile(f) and \
(f.endswith('.h5') or f.endswith('.hdf5')):
for j in pyanitools.anidataloader(f):
conformations = j['energies'].shape[0]
l2 += ceil(conformations / chunksize)
# compute number of chunks using iterator
l3 = len(list(bs))
# these lengths should match
self.assertEqual(l1, l2)
self.assertEqual(l1, l3)
def testNumBatches(self):
chunksize = 64
batch_chunks = 4
# compute number of batches using batch sampler
bs = torchani.data.BatchSampler(self.ds, chunksize, batch_chunks)
l1 = len(bs)
# compute number of batches by simple math
bs2 = torchani.data.BatchSampler(self.ds, chunksize, 1)
l2 = ceil(len(bs2) / batch_chunks)
# compute number of batches using iterator
l3 = len(list(bs))
# these lengths should match
self.assertEqual(l1, l2)
self.assertEqual(l1, l3)
def testBatchSize1(self):
bs = torchani.data.BatchSampler(self.ds, 1, 1)
self.assertEqual(len(bs), len(self.ds))
def testSplitSize(self):
chunksize = 64
bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
chunks = len(bs)
ds1, ds2 = torchani.data.random_split(
self.ds, [200, chunks-200], chunksize)
bs1 = torchani.data.BatchSampler(ds1, chunksize, 1)
bs2 = torchani.data.BatchSampler(ds2, chunksize, 1)
self.assertEqual(len(bs1), 200)
self.assertEqual(len(bs2), chunks-200)
def testSplitNoOverlap(self):
chunksize = 64
bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
chunks = len(bs)
ds1, ds2 = torchani.data.random_split(
self.ds, [200, chunks-200], chunksize)
indices1 = ds1.dataset.indices
indices2 = ds2.dataset.indices
self.assertEqual(len(indices1), len(ds1))
self.assertEqual(len(indices2), len(ds2))
self.assertEqual(len(indices1), len(set(indices1)))
self.assertEqual(len(indices2), len(set(indices2)))
self.assertEqual(len(self.ds), len(set(indices1+indices2)))
def _testMolSizes(self, ds):
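# every pair of indices that falls in the same cumulative_sizes bucket must
# map to the same molecule id; indices from different buckets must not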
for i in range(len(ds)):
left = bisect(ds.cumulative_sizes, i)
moli = ds[i][0].item()
for j in range(len(ds)):
left2 = bisect(ds.cumulative_sizes, j)
molj = ds[j][0].item()
if left == left2:
self.assertEqual(moli, molj)
else:
if moli == molj:
print(i, j)
self.assertNotEqual(moli, molj)
def testMolSizes(self):
chunksize = 8
bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
chunks = len(bs)
ds1, ds2 = torchani.data.random_split(
self.ds, [50, chunks-50], chunksize)
self._testMolSizes(ds1)
def testSaveLoad(self):
chunksize = 8
bs = torchani.data.BatchSampler(self.ds, chunksize, 1)
chunks = len(bs)
ds1, ds2 = torchani.data.random_split(
self.ds, [50, chunks-50], chunksize)
tmpdir = tempfile.TemporaryDirectory()
tmpdirname = tmpdir.name
filename = os.path.join(tmpdirname, 'test.obj')
def _test_chunksize(self, chunksize):
ds = torchani.data.ANIDataset(path, chunksize)
for i in ds:
self.assertLessEqual(i[0].shape[0], chunksize)
with open(filename, 'wb') as f:
dump(ds1, f)
def testChunk64(self):
self._test_chunksize(64)
with open(filename, 'rb') as f:
ds1_loaded = load(f)
def testChunk128(self):
self._test_chunksize(128)
self.assertEqual(len(ds1), len(ds1_loaded))
self.assertListEqual(ds1.sizes, ds1_loaded.sizes)
self.assertIsInstance(ds1_loaded, torchani.data.ANIDataset)
def testChunk32(self):
self._test_chunksize(32)
for i in range(len(ds1)):
i1 = ds1[i]
i2 = ds1_loaded[i]
molid1 = i1[0].item()
molid2 = i2[0].item()
self.assertEqual(molid1, molid2)
xyz1 = i1[1]
xyz2 = i2[1]
maxdiff = torch.max(torch.abs(xyz1-xyz2)).item()
self.assertEqual(maxdiff, 0)
e1 = i1[2].item()
e2 = i2[2].item()
self.assertEqual(e1, e2)
def testChunk256(self):
self._test_chunksize(256)
if __name__ == '__main__':
unittest.main()
@@ -6,6 +6,5 @@ from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV',
'PerSpeciesFromNeuroChem', 'data', 'buildin_const_file',
'buildin_sae_file', 'buildin_network_dir', 'buildin_dataset_dir',
'buildin_model_prefix', 'buildin_ensembles', 'default_dtype',
'default_device']
'buildin_sae_file', 'buildin_network_dir', 'buildin_model_prefix',
'buildin_ensembles', 'default_dtype', 'default_device']
from .pyanitools import anidataloader
from os import listdir
from torch.utils.data import Dataset
from os.path import join, isfile, isdir
from torch import tensor, full_like, long
from torch.utils.data import Dataset, Subset, TensorDataset, ConcatDataset
from torch.utils.data.dataloader import default_collate
from math import ceil
from . import default_dtype
from random import shuffle
from itertools import chain, accumulate
from os import listdir
from .pyanitools import anidataloader
import torch
class ANIDataset(Dataset):
"""Dataset with extra information for ANI applications
Attributes
----------
dataset : Dataset
The dataset
sizes : sequence
Number of conformations for each molecule
cumulative_sizes : sequence
Cumulative sums of `sizes`
species : list
Species list for each molecule
"""
def __init__(self, dataset, sizes, species):
def __init__(self, path, chunk_size, randomize_chunk=True):
super(ANIDataset, self).__init__()
self.dataset = dataset
self.sizes = sizes
self.cumulative_sizes = list(accumulate(sizes))
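# e.g. sizes = [3, 5, 2] -> cumulative_sizes = [3, 8, 10]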
self.species = species
def __getitem__(self, idx):
return self.dataset[idx]
def __len__(self):
return len(self.dataset)
def load_dataset(path, dtype=default_dtype):
"""The returned dataset has cumulative_sizes and molecule_sizes"""
# get name of files storing data
files = []
if isdir(path):
for f in listdir(path):
f = join(path, f)
if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
files.append(f)
elif isfile(path):
files = [path]
else:
raise ValueError('Bad path')
# read tensors from file and build a dataset
species = []
molecule_id = 0
datasets = []
for f in files:
for m in anidataloader(f):
coordinates = tensor(m['coordinates'], dtype=dtype)
energies = tensor(m['energies'], dtype=dtype)
_molecule_id = full_like(energies, molecule_id).type(long)
datasets.append(TensorDataset(_molecule_id, coordinates, energies))
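# indexing this TensorDataset yields one conformation as a
# (molecule_id, coordinates, energy) triple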
species.append(m['species'])
molecule_id += 1
dataset = ConcatDataset(datasets)
sizes = [len(x) for x in dataset.datasets]
return ANIDataset(dataset, sizes, species)
# get name of files storing data
files = []
if isdir(path):
for f in listdir(path):
f = join(path, f)
if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
files.append(f)
elif isfile(path):
files = [path]
else:
raise ValueError('Bad path')
# generate chunks
chunks = []
for f in files:
for m in anidataloader(f):
xyz = torch.from_numpy(m['coordinates'])
conformations = xyz.shape[0]
energies = torch.from_numpy(m['energies'])
species = m['species']
if randomize_chunk:
indices = torch.randperm(conformations)
else:
indices = torch.arange(conformations, dtype=torch.int64)
num_chunks = (conformations + chunk_size - 1) // chunk_size
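# ceiling division: ceil(conformations / chunk_size) using integer math only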
for i in range(num_chunks):
chunk_start = i * chunk_size
chunk_end = min(chunk_start + chunk_size, conformations)
chunk_indices = indices[chunk_start:chunk_end]
chunk_xyz = xyz.index_select(0, chunk_indices)
chunk_energies = energies.index_select(0, chunk_indices)
chunks.append((chunk_xyz, chunk_energies, species))
self.chunks = chunks
class BatchSampler(object):
def __init__(self, source, chunk_size, batch_chunks):
if not isinstance(source, ANIDataset):
raise ValueError("BatchSampler must take ANIDataset as input")
self.source = source
self.chunk_size = chunk_size
self.batch_chunks = batch_chunks
def _concated_index(self, molecule, conformation):
"""
Get the index in the dataset of the specified conformation
of the specified molecule.
"""
src = self.source
cumulative_sizes = [0] + src.cumulative_sizes
return cumulative_sizes[molecule] + conformation
def __iter__(self):
molecules = len(self.source.sizes)
sizes = self.source.sizes
"""Number of conformations of each molecule"""
unfinished = list(zip(range(molecules), [0] * molecules))
"""List of pairs (molecule, progress) storing the current progress
of iterating over each molecule."""
batch = []
batch_molecules = 0
"""The number of molecules already in batch"""
while len(unfinished) > 0:
new_unfinished = []
for molecule, progress in unfinished:
size = sizes[molecule]
# the last incomplete chunk is not dropped
end = min(progress + self.chunk_size, size)
if end < size:
new_unfinished.append((molecule, end))
batch += [self._concated_index(molecule, x)
for x in range(progress, end)]
batch_molecules += 1
if batch_molecules >= self.batch_chunks:
yield batch
batch = []
batch_molecules = 0
unfinished = new_unfinished
# the last incomplete batch is not dropped
if len(batch) > 0:
yield batch
def __getitem__(self, idx):
return self.chunks[idx]
def __len__(self):
sizes = self.source.sizes
chunks = [ceil(x/self.chunk_size) for x in sizes]
chunks = sum(chunks)
return ceil(chunks / self.batch_chunks)
def collate(batch):
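# group conformations by molecule before stacking: conformations of the same
# molecule share an atom count and species list, so default_collate can stack
# their coordinates into a single tensor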
by_molecules = {}
for molecule_id, xyz, energy in batch:
molecule_id = molecule_id.item()
if molecule_id not in by_molecules:
by_molecules[molecule_id] = []
by_molecules[molecule_id].append((xyz, energy))
for i in by_molecules:
by_molecules[i] = default_collate(by_molecules[i])
return by_molecules
def random_split(dataset, num_chunks, chunk_size):
"""
Randomly split a dataset into non-overlapping new datasets of given lengths.
The splitting is done by chunk, which keeps batching possible: the whole
dataset is first split into chunks of the specified size, where each chunk
contains different conformations of the same isomer/molecule. These chunks
are then randomly shuffled and split according to the given `num_chunks`.
After splitting, chunks that belong to the same molecule/isomer within the
same subset are merged to allow larger batches.
Parameters
----------
dataset : Dataset
Dataset to be split
num_chunks : sequence
Number of chunks in each of the splits to be produced
chunk_size : integer
Size of each chunk
"""
chunks = list(BatchSampler(dataset, chunk_size, 1))
shuffle(chunks)
if sum(num_chunks) != len(chunks):
raise ValueError(
"""Sum of input number of chunks does not equal the length of the
total dataset!""")
offset = 0
subsets = []
for i in num_chunks:
_chunks = chunks[offset:offset+i]
offset += i
# merge chunks by molecule
by_molecules = {}
for chunk in _chunks:
molecule_id = dataset[chunk[0]][0].item()
if molecule_id not in by_molecules:
by_molecules[molecule_id] = []
by_molecules[molecule_id] += chunk
_chunks = list(by_molecules.values())
shuffle(_chunks)
# construct subset
sizes = [len(j) for j in _chunks]
indices = list(chain.from_iterable(_chunks))
_dataset = Subset(dataset, indices)
_dataset = ANIDataset(_dataset, sizes, dataset.species)
subsets.append(_dataset)
return subsets
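# A usage sketch of the split-by-chunk API above (values are assumptions, not
# part of this diff): split a loaded dataset into two subsets by chunk count.
#
#     ds = load_dataset('dataset')              # directory of .h5 files
#     n_chunks = len(BatchSampler(ds, 64, 1))   # total number of chunks of size 64
#     train, val = random_split(ds, [n_chunks - 100, 100], 64)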
return len(self.chunks)