Unverified Commit f8edffe3 authored by Gao, Xiang, committed by GitHub

Improve more on new dataset API (#434)

* Improve new dataset API

* Improve more on new dataset API

* fix

* fix reentrance

* Allow all intermediate state of transformation to be reentered

* Add length inference

* fix

* split by ratio

* add dataloader example

* add test for data loader
parent 4f834e2c
@@ -34,3 +34,5 @@ dist
 *.qdstrm
 *.zip
 Untitled.ipynb
+/nnp_training.py
+/test*.py
@@ -82,9 +82,7 @@ except NameError:
 dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560
 
-dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-size = len(dataset)
-training, validation = dataset.split(int(0.8 * size), None)
+training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
 training = training.collate(batch_size).cache()
 validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
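This hunk replaces the size-based `split` call with the new ratio-based form. A minimal sketch of the equivalence (the `dataset` name is shorthand for the chained calls above):

.. code-block:: python

    # new: fractions of the whole dataset
    training, validation = dataset.split(0.8, None)
    # equivalent to the old absolute-size form:
    # training, validation = dataset.split(int(0.8 * len(dataset)), None)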
@@ -49,9 +49,7 @@ dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 2560
 
-dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-size = len(dataset)
-training, validation = dataset.split(int(0.8 * size), None)
+training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
 training = training.collate(batch_size).cache()
 validation = validation.collate(batch_size).cache()
 import os
 import torch
 import torchani
 import unittest
 
 path = os.path.dirname(os.path.realpath(__file__))
-dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
+dataset_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 256
 ani1x = torchani.models.ANI1x()
 consts = ani1x.consts
@@ -34,6 +35,87 @@ class TestData(unittest.TestCase):
         non_padding = (species >= 0)[:, -1].nonzero()
         self.assertGreater(non_padding.numel(), 0)
 
+    def testReEnter(self):
+        # make sure that a dataset can be iterated multiple times
+        ds = torchani.data.load(dataset_path)
+        for d in ds:
+            pass
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.subtract_self_energies(sae_dict)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.species_to_indices()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.shuffle()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.collate(batch_size)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.cache()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+    def testShapeInference(self):
+        shifter = torchani.EnergyShifter(None)
+        ds = torchani.data.load(dataset_path).subtract_self_energies(shifter)
+        len(ds)
+        ds = ds.species_to_indices()
+        len(ds)
+        ds = ds.shuffle()
+        len(ds)
+        ds = ds.collate(batch_size)
+        len(ds)
+
+    def testDataloader(self):
+        shifter = torchani.EnergyShifter(None)
+        dataset = list(torchani.data.load(dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle())
+        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
+        for i in loader:
+            pass
+
 
 if __name__ == '__main__':
     unittest.main()
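What these tests pin down is that every stage of the pipeline is a re-enterable iterable rather than a one-shot generator. A minimal standalone sketch of the difference, independent of torchani (names are illustrative):

.. code-block:: python

    def gen():
        yield from range(3)

    g = gen()
    list(g)   # [0, 1, 2]
    list(g)   # [] -- a plain generator is exhausted after one pass

    class IterableAdapter:
        """Re-enterable: builds a fresh iterator on every __iter__."""
        def __init__(self, iterable_factory):
            self.iterable_factory = iterable_factory

        def __iter__(self):
            return iter(self.iterable_factory())

    ia = IterableAdapter(gen)
    list(ia)  # [0, 1, 2]
    list(ia)  # [0, 1, 2] -- can be iterated again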
@@ -25,7 +25,7 @@ You can also use `split` to split the iterable to pieces. Use `split` as:
 
 .. code-block:: python
 
-    it.split(size1, size2, None)
+    it.split(ratio1, ratio2, None)
 
 where the None at the end indicates that we want to use all of the rest
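For concreteness, a small illustration of the ratio semantics (the sizes are invented for the example):

.. code-block:: python

    # on an iterable of 100 conformations:
    a, b, rest = it.split(0.5, 0.3, None)
    # a gets the first 50 conformations, b the next 30,
    # and rest the remaining 20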
@@ -34,17 +34,23 @@ Example:
 
 .. code-block:: python
 
     energy_shifter = torchani.utils.EnergyShifter(None)
 
-    dataset = torchani.data.load(path).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-    size = len(dataset)
-    training, validation = dataset.split(int(0.8 * size), None)
+    training, validation = torchani.data.load(path).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
     training = training.collate(batch_size).cache()
     validation = validation.collate(batch_size).cache()
 
+If the above approach takes too much memory for you, you can use a dataloader
+with multiprocessing to achieve comparable performance with less memory usage:
+
+.. code-block:: python
+
+    training, validation = torchani.data.load(path).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
+    training = torch.utils.data.DataLoader(list(training), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
+    validation = torch.utils.data.DataLoader(list(validation), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
 """
 from os.path import join, isfile, isdir
 import os
 from ._pyanitools import anidataloader
 import torch
 from .. import utils
 import importlib
 import functools
@@ -52,6 +58,7 @@ import math
 import random
 from collections import Counter
 import numpy
+import gc
 
 PKBAR_INSTALLED = importlib.util.find_spec('pkbar') is not None # type: ignore
 
 if PKBAR_INSTALLED:
@@ -69,26 +76,59 @@ PADDING = {
 }
 
+def collate_fn(samples):
+    return utils.stack_with_padding(samples, PADDING)
+
+
+class IterableAdapter:
+    """https://stackoverflow.com/a/39564774"""
+    def __init__(self, iterable_factory, length=None):
+        self.iterable_factory = iterable_factory
+        self.length = length
+
+    def __iter__(self):
+        return iter(self.iterable_factory())
+
+
+class IterableAdapterWithLength(IterableAdapter):
+    def __init__(self, iterable_factory, length):
+        super().__init__(iterable_factory)
+        self.length = length
+
+    def __len__(self):
+        return self.length
+
+
 class Transformations:
     """Convert one reenterable iterable to another reenterable iterable"""
 
     @staticmethod
-    def species_to_indices(iter_, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
+    def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
         if species_order == 'periodic_table':
             species_order = utils.PERIODIC_TABLE
         idx = {k: i for i, k in enumerate(species_order)}
-        for d in iter_:
-            d['species'] = numpy.array([idx[s] for s in d['species']])
-            yield d
+
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                d['species'] = numpy.array([idx[s] for s in d['species']])
+                yield d
+        try:
+            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
 
     @staticmethod
-    def subtract_self_energies(iter_, self_energies=None):
+    def subtract_self_energies(reenterable_iterable, self_energies=None):
         intercept = 0.0
+        shape_inference = False
         if isinstance(self_energies, utils.EnergyShifter):
+            shape_inference = True
             shifter = self_energies
             self_energies = {}
             counts = {}
             Y = []
-            for n, d in enumerate(iter_):
+            for n, d in enumerate(reenterable_iterable):
                 species = d['species']
                 count = Counter()
                 for s in species:
@@ -115,22 +155,28 @@ class Transformations:
             for s, e in zip(species, sae_):
                 self_energies[s] = e
             shifter.__init__(sae, shifter.fit_intercept)
-        for d in iter_:
-            e = intercept
-            for s in d['species']:
-                e += self_energies[s]
-            d['energies'] -= e
-            yield d
+        gc.collect()
+
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                e = intercept
+                for s in d['species']:
+                    e += self_energies[s]
+                d['energies'] -= e
+                yield d
+        if shape_inference:
+            return IterableAdapterWithLength(reenterable_iterable_factory, n)
+        return IterableAdapter(reenterable_iterable_factory)
 
     @staticmethod
-    def remove_outliers(iter_, threshold1=15.0, threshold2=8.0):
+    def remove_outliers(reenterable_iterable, threshold1=15.0, threshold2=8.0):
         assert 'subtract_self_energies', "Transformation remove_outliers can only run after subtract_self_energies"
 
         # pass 1: remove everything that has per-atom energy > threshold1
         def scaled_energy(x):
             num_atoms = len(x['species'])
             return abs(x['energies']) / math.sqrt(num_atoms)
-        filtered = [x for x in iter_ if scaled_energy(x) < threshold1]
+        filtered = IterableAdapter(lambda: (x for x in reenterable_iterable if scaled_energy(x) < threshold1))
 
         # pass 2: compute those that are outside the mean by threshold2 * std
         n = 0
@@ -143,46 +189,61 @@ class Transformations:
         mean /= n
         std = math.sqrt(std / n - mean ** 2)
 
-        return filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered)
+        return IterableAdapter(lambda: filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered))
 
     @staticmethod
-    def shuffle(iter_):
-        list_ = list(iter_)
+    def shuffle(reenterable_iterable):
+        list_ = list(reenterable_iterable)
+        del reenterable_iterable
+        gc.collect()
         random.shuffle(list_)
         return list_
 
     @staticmethod
-    def cache(iter_):
-        return list(iter_)
+    def cache(reenterable_iterable):
+        ret = list(reenterable_iterable)
+        del reenterable_iterable
+        gc.collect()
+        return ret
 
     @staticmethod
-    def collate(iter_, batch_size):
-        batch = []
-        i = 0
-        for d in iter_:
-            d = {k: torch.as_tensor(d[k]) for k in d}
-            batch.append(d)
-            i += 1
-            if i == batch_size:
-                i = 0
-                yield utils.stack_with_padding(batch, PADDING)
-                batch = []
-        if len(batch) > 0:
-            yield utils.stack_with_padding(batch, PADDING)
+    def collate(reenterable_iterable, batch_size):
+        def reenterable_iterable_factory():
+            batch = []
+            i = 0
+            for d in reenterable_iterable:
+                batch.append(d)
+                i += 1
+                if i == batch_size:
+                    i = 0
+                    yield collate_fn(batch)
+                    batch = []
+            if len(batch) > 0:
+                yield collate_fn(batch)
+        try:
+            # number of batches is ceil(len / batch_size)
+            length = (len(reenterable_iterable) + batch_size - 1) // batch_size
+            return IterableAdapterWithLength(reenterable_iterable_factory, length)
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
 
     @staticmethod
-    def pin_memory(iter_):
-        for d in iter_:
-            yield {k: d[k].pin_memory() for k in d}
+    def pin_memory(reenterable_iterable):
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                yield {k: d[k].pin_memory() for k in d}
+        try:
+            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
 
 
 class TransformableIterable:
-    def __init__(self, wrapped_iter, transformations=()):
-        self.wrapped_iter = wrapped_iter
+    def __init__(self, wrapped_iterable, transformations=()):
+        self.wrapped_iterable = wrapped_iterable
         self.transformations = transformations
 
     def __iter__(self):
-        return iter(self.wrapped_iter)
+        return iter(self.wrapped_iterable)
 
     def __getattr__(self, name):
         transformation = getattr(Transformations, name)
@@ -190,40 +251,35 @@ class TransformableIterable:
 
         @functools.wraps(transformation)
         def f(*args, **kwargs):
             return TransformableIterable(
-                transformation(self, *args, **kwargs),
+                transformation(self.wrapped_iterable, *args, **kwargs),
                 self.transformations + (name,))
         return f
 
     def split(self, *nums):
+        length = len(self)
         iters = []
        self_iter = iter(self)
         for n in nums:
             list_ = []
             if n is not None:
-                for _ in range(n):
+                for _ in range(int(n * length)):
                     list_.append(next(self_iter))
             else:
                 for i in self_iter:
                     list_.append(i)
             iters.append(TransformableIterable(list_, self.transformations + ('split',)))
+        del self_iter
+        gc.collect()
         return iters
 
     def __len__(self):
-        return len(self.wrapped_iter)
+        return len(self.wrapped_iterable)
 
 
 def load(path, additional_properties=()):
     properties = PROPERTIES + additional_properties
 
-    # https://stackoverflow.com/a/39564774
-    class IterableAdapter:
-        def __init__(self, iterator_factory):
-            self.iterator_factory = iterator_factory
-
-        def __iter__(self):
-            return self.iterator_factory()
-
     def h5_files(path):
         """yield file name of all h5 files in a path"""
         if isdir(path):
@@ -259,4 +315,4 @@ def load(path, additional_properties=()):
     return TransformableIterable(IterableAdapter(lambda: conformations()))
 
-__all__ = ['load']
+__all__ = ['load', 'collate_fn']
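All the `try: ... except TypeError:` blocks in the transformations above rely on `len()` raising `TypeError` for iterables without `__len__`; that is how length inference decides between `IterableAdapterWithLength` and a plain `IterableAdapter`. A minimal sketch of the probe (illustrative names only):

.. code-block:: python

    class NoLen:
        def __iter__(self):
            return iter(())

    def probe(iterable):
        try:
            return len(iterable)      # propagate the length when available
        except TypeError:
            return None               # fall back to a length-less adapter

    probe([1, 2, 3])   # 3
    probe(NoLen())     # None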
@@ -8,11 +8,15 @@ from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2
 from .nn import SpeciesEnergies
 
 
+def empty_list():
+    return []
+
+
 def stack_with_padding(properties, padding):
-    output = defaultdict(lambda: [])
+    output = defaultdict(empty_list)
     for p in properties:
         for k, v in p.items():
-            output[k].append(v)
+            output[k].append(torch.as_tensor(v))
     for k, v in output.items():
         if v[0].dim() == 0:
             output[k] = torch.stack(v)
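Replacing `defaultdict(lambda: [])` with the named, module-level `empty_list` is presumably what lets batches cross process boundaries: with `num_workers > 0`, DataLoader workers send collated batches back to the parent process via pickling, and a `defaultdict` pickles its `default_factory`, which fails for a lambda. A minimal sketch of the difference:

.. code-block:: python

    import pickle
    from collections import defaultdict

    def empty_list():
        return []

    pickle.dumps(defaultdict(empty_list))   # works: named functions pickle by reference
    pickle.dumps(defaultdict(lambda: []))   # raises: lambdas cannot be pickled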