Unverified commit f8edffe3, authored by Gao, Xiang, committed by GitHub

Improve more on new dataset API (#434)

* Improve new dataset API

* Improve more on new dataset API

* fix

* fix reentrance

* Allow all intermediate states of the transformation pipeline to be re-entered

* Add length inference

* fix

* split by ratio

* add dataloader example

* add test for data loader
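
Taken together, the user-visible change is that a dataset pipeline is now built as a single re-enterable chain, and `split` takes ratios rather than element counts. A minimal sketch of the new pattern, assuming the sample dataset paths used in the example scripts changed below:

    import torchani

    # dspath is illustrative; the examples in this commit use the ANI sample datasets
    dspath = '../dataset/ani-1x/sample.h5'
    batch_size = 2560

    energy_shifter = torchani.utils.EnergyShifter(None)

    # One chained pipeline; split(0.8, None) takes a ratio, with the
    # trailing None meaning "everything that remains".
    training, validation = torchani.data.load(dspath) \
        .subtract_self_energies(energy_shifter) \
        .species_to_indices() \
        .shuffle() \
        .split(0.8, None)

    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()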
parent 4f834e2c
...
@@ -34,3 +34,5 @@ dist
 *.qdstrm
 *.zip
 Untitled.ipynb
+/nnp_training.py
+/test*.py
...
@@ -82,9 +82,7 @@ except NameError:
 dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560
-dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-size = len(dataset)
-training, validation = dataset.split(int(0.8 * size), None)
+training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
 training = training.collate(batch_size).cache()
 validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
...
...
@@ -49,9 +49,7 @@ dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 2560
-dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-size = len(dataset)
-training, validation = dataset.split(int(0.8 * size), None)
+training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
 training = training.collate(batch_size).cache()
 validation = validation.collate(batch_size).cache()
...
 import os
+import torch
 import torchani
 import unittest

 path = os.path.dirname(os.path.realpath(__file__))
-dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
+dataset_path = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 256
 ani1x = torchani.models.ANI1x()
 consts = ani1x.consts
...
@@ -34,6 +35,87 @@ class TestData(unittest.TestCase):
         non_padding = (species >= 0)[:, -1].nonzero()
         self.assertGreater(non_padding.numel(), 0)
+
+    def testReEnter(self):
+        # make sure that a dataset can be iterated multiple times
+        # at every stage of the transformation pipeline
+        ds = torchani.data.load(dataset_path)
+        for d in ds:
+            pass
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.subtract_self_energies(sae_dict)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.species_to_indices()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.shuffle()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.collate(batch_size)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+        ds = ds.cache()
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+        entered = False
+        for d in ds:
+            entered = True
+        self.assertTrue(entered)
+
+    def testShapeInference(self):
+        # length inference: len() must not raise at any stage
+        shifter = torchani.EnergyShifter(None)
+        ds = torchani.data.load(dataset_path).subtract_self_energies(shifter)
+        len(ds)
+        ds = ds.species_to_indices()
+        len(ds)
+        ds = ds.shuffle()
+        len(ds)
+        ds = ds.collate(batch_size)
+        len(ds)
+
+    def testDataloader(self):
+        shifter = torchani.EnergyShifter(None)
+        dataset = list(torchani.data.load(dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle())
+        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
+                                             collate_fn=torchani.data.collate_fn,
+                                             num_workers=64)
+        for i in loader:
+            pass

 if __name__ == '__main__':
     unittest.main()
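
The new `testDataloader` exercises interop with `torch.utils.data`: because each stage is now a re-enterable iterable rather than a one-shot generator, it can be materialized with `list()` into a map-style dataset and batched in worker processes via the newly exported `torchani.data.collate_fn`. A trimmed-down sketch of the same pattern (the path, worker count, and batch keys are assumptions, not taken from the diff):

    import torch
    import torchani

    dataset_path = '../dataset/ani-1x/sample.h5'  # illustrative path
    batch_size = 256

    shifter = torchani.EnergyShifter(None)
    dataset = list(torchani.data.load(dataset_path)
                   .subtract_self_energies(shifter)
                   .species_to_indices()
                   .shuffle())
    # collate_fn pads conformers of different sizes into a single batch
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         collate_fn=torchani.data.collate_fn,
                                         num_workers=2)
    for batch in loader:
        species, energies = batch['species'], batch['energies']  # assumed keys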
...
@@ -25,7 +25,7 @@ You can also use `split` to split the iterable to pieces. Use `split` as:

 .. code-block:: python

-    it.split(size1, size2, None)
+    it.split(ratio1, ratio2, None)

 where the None at the end indicates that we want to use all of the rest.
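
For example, to carve the iterable into three pieces by ratio (the names here are illustrative):

.. code-block:: python

    # 80% training, 10% validation, and the trailing None takes the rest
    training, validation, test = it.split(0.8, 0.1, None)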
...
@@ -34,17 +34,23 @@ Example:

 .. code-block:: python

     energy_shifter = torchani.utils.EnergyShifter(None)
-    dataset = torchani.data.load(path).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
-    size = len(dataset)
-    training, validation = dataset.split(int(0.8 * size), None)
+    training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
     training = training.collate(batch_size).cache()
     validation = validation.collate(batch_size).cache()

+If the above approach takes too much memory, you can instead use a DataLoader
+with multiprocessing to get comparable performance at lower memory usage:
+
+.. code-block:: python
+
+    training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
+    training = torch.utils.data.DataLoader(list(training), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
+    validation = torch.utils.data.DataLoader(list(validation), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
 """
 from os.path import join, isfile, isdir
 import os
 from ._pyanitools import anidataloader
+import torch
 from .. import utils
 import importlib
 import functools
...
@@ -52,6 +58,7 @@ import math
 import random
 from collections import Counter
 import numpy
+import gc

 PKBAR_INSTALLED = importlib.util.find_spec('pkbar') is not None  # type: ignore

 if PKBAR_INSTALLED:
...
@@ -69,26 +76,59 @@ PADDING = {
 }
+def collate_fn(samples):
+    return utils.stack_with_padding(samples, PADDING)
+
+
+class IterableAdapter:
+    """https://stackoverflow.com/a/39564774"""
+
+    def __init__(self, iterable_factory, length=None):
+        self.iterable_factory = iterable_factory
+        self.length = length
+
+    def __iter__(self):
+        return iter(self.iterable_factory())
+
+
+class IterableAdapterWithLength(IterableAdapter):
+
+    def __init__(self, iterable_factory, length):
+        super().__init__(iterable_factory)
+        self.length = length
+
+    def __len__(self):
+        return self.length
+
+
 class Transformations:
+    """Convert one reenterable iterable to another reenterable iterable"""

     @staticmethod
-    def species_to_indices(iter_, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
+    def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
         if species_order == 'periodic_table':
             species_order = utils.PERIODIC_TABLE
         idx = {k: i for i, k in enumerate(species_order)}
-        for d in iter_:
-            d['species'] = numpy.array([idx[s] for s in d['species']])
-            yield d
+
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                d['species'] = numpy.array([idx[s] for s in d['species']])
+                yield d
+        try:
+            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
     @staticmethod
-    def subtract_self_energies(iter_, self_energies=None):
+    def subtract_self_energies(reenterable_iterable, self_energies=None):
         intercept = 0.0
+        shape_inference = False
         if isinstance(self_energies, utils.EnergyShifter):
+            shape_inference = True
             shifter = self_energies
             self_energies = {}
             counts = {}
             Y = []
-            for n, d in enumerate(iter_):
+            for n, d in enumerate(reenterable_iterable):
                 species = d['species']
                 count = Counter()
                 for s in species:
...
@@ -115,22 +155,28 @@ class Transformations:
             for s, e in zip(species, sae_):
                 self_energies[s] = e
             shifter.__init__(sae, shifter.fit_intercept)
+            gc.collect()

-        for d in iter_:
-            e = intercept
-            for s in d['species']:
-                e += self_energies[s]
-            d['energies'] -= e
-            yield d
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                e = intercept
+                for s in d['species']:
+                    e += self_energies[s]
+                d['energies'] -= e
+                yield d
+        if shape_inference:
+            return IterableAdapterWithLength(reenterable_iterable_factory, n)
+        return IterableAdapter(reenterable_iterable_factory)
     @staticmethod
-    def remove_outliers(iter_, threshold1=15.0, threshold2=8.0):
+    def remove_outliers(reenterable_iterable, threshold1=15.0, threshold2=8.0):
         assert 'subtract_self_energies', "Transformation remove_outliers can only run after subtract_self_energies"

         # pass 1: remove everything that has per-atom energy > threshold1
         def scaled_energy(x):
             num_atoms = len(x['species'])
             return abs(x['energies']) / math.sqrt(num_atoms)
-        filtered = [x for x in iter_ if scaled_energy(x) < threshold1]
+        filtered = IterableAdapter(lambda: (x for x in reenterable_iterable if scaled_energy(x) < threshold1))

         # pass 2: compute those that are outside the mean by threshold2 * std
         n = 0
...
@@ -143,46 +189,61 @@ class Transformations:
         mean /= n
         std = math.sqrt(std / n - mean ** 2)
-        return filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered)
+        return IterableAdapter(lambda: filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered))
     @staticmethod
-    def shuffle(iter_):
-        list_ = list(iter_)
+    def shuffle(reenterable_iterable):
+        list_ = list(reenterable_iterable)
+        del reenterable_iterable
+        gc.collect()
         random.shuffle(list_)
         return list_

     @staticmethod
-    def cache(iter_):
-        return list(iter_)
+    def cache(reenterable_iterable):
+        ret = list(reenterable_iterable)
+        del reenterable_iterable
+        gc.collect()
+        return ret
     @staticmethod
-    def collate(iter_, batch_size):
-        batch = []
-        i = 0
-        for d in iter_:
-            batch.append(d)
-            i += 1
-            if i == batch_size:
-                i = 0
-                yield utils.stack_with_padding(batch, PADDING)
-                batch = []
-        if len(batch) > 0:
-            yield utils.stack_with_padding(batch, PADDING)
+    def collate(reenterable_iterable, batch_size):
+        def reenterable_iterable_factory():
+            batch = []
+            i = 0
+            for d in reenterable_iterable:
+                d = {k: torch.as_tensor(d[k]) for k in d}
+                batch.append(d)
+                i += 1
+                if i == batch_size:
+                    i = 0
+                    yield collate_fn(batch)
+                    batch = []
+            if len(batch) > 0:
+                yield collate_fn(batch)
+        try:
+            length = (len(reenterable_iterable) + batch_size - 1) // batch_size
+            return IterableAdapterWithLength(reenterable_iterable_factory, length)
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
     @staticmethod
-    def pin_memory(iter_):
-        for d in iter_:
-            yield {k: d[k].pin_memory() for k in d}
+    def pin_memory(reenterable_iterable):
+        def reenterable_iterable_factory():
+            for d in reenterable_iterable:
+                yield {k: d[k].pin_memory() for k in d}
+        try:
+            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
+        except TypeError:
+            return IterableAdapter(reenterable_iterable_factory)
 class TransformableIterable:
-    def __init__(self, wrapped_iter, transformations=()):
-        self.wrapped_iter = wrapped_iter
+    def __init__(self, wrapped_iterable, transformations=()):
+        self.wrapped_iterable = wrapped_iterable
         self.transformations = transformations

     def __iter__(self):
-        return iter(self.wrapped_iter)
+        return iter(self.wrapped_iterable)

     def __getattr__(self, name):
         transformation = getattr(Transformations, name)
...
@@ -190,40 +251,35 @@ class TransformableIterable:
         @functools.wraps(transformation)
         def f(*args, **kwargs):
             return TransformableIterable(
-                transformation(self, *args, **kwargs),
+                transformation(self.wrapped_iterable, *args, **kwargs),
                 self.transformations + (name,))
         return f
     def split(self, *nums):
+        length = len(self)
         iters = []
         self_iter = iter(self)
         for n in nums:
             list_ = []
             if n is not None:
-                for _ in range(n):
+                for _ in range(int(n * length)):
                     list_.append(next(self_iter))
             else:
                 for i in self_iter:
                     list_.append(i)
             iters.append(TransformableIterable(list_, self.transformations + ('split',)))
+        del self_iter
+        gc.collect()
         return iters

     def __len__(self):
-        return len(self.wrapped_iter)
+        return len(self.wrapped_iterable)
 def load(path, additional_properties=()):
     properties = PROPERTIES + additional_properties

-    # https://stackoverflow.com/a/39564774
-    class IterableAdapter:
-        def __init__(self, iterator_factory):
-            self.iterator_factory = iterator_factory
-
-        def __iter__(self):
-            return self.iterator_factory()
-
     def h5_files(path):
         """yield file name of all h5 files in a path"""
         if isdir(path):
...
@@ -259,4 +315,4 @@ def load(path, additional_properties=()):
     return TransformableIterable(IterableAdapter(lambda: conformations()))

-__all__ = ['load']
+__all__ = ['load', 'collate_fn']
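
The module-level IterableAdapter is the mechanism behind the re-entrance fixes above: a bare generator is exhausted after one pass, while an object whose `__iter__` rebuilds the iterator from a factory can be traversed any number of times. A standalone sketch of the idea (local names, not torchani API):

    class ReiterableSketch:
        def __init__(self, factory):
            # factory: zero-argument callable returning a fresh iterator
            self.factory = factory

        def __iter__(self):
            return iter(self.factory())

    gen = (x * x for x in range(3))
    print(list(gen), list(gen))  # [0, 1, 4] [] -- a generator is one-shot

    it = ReiterableSketch(lambda: (x * x for x in range(3)))
    print(list(it), list(it))    # [0, 1, 4] [0, 1, 4] -- re-enterable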
...
@@ -8,11 +8,15 @@ from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2
 from .nn import SpeciesEnergies


+def empty_list():
+    return []
+
+
 def stack_with_padding(properties, padding):
-    output = defaultdict(lambda: [])
+    output = defaultdict(empty_list)
     for p in properties:
         for k, v in p.items():
-            output[k].append(v)
+            output[k].append(torch.as_tensor(v))
     for k, v in output.items():
         if v[0].dim() == 0:
             output[k] = torch.stack(v)
...
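
Swapping `defaultdict(lambda: [])` for the module-level `empty_list` in stack_with_padding is presumably what lets batches cross DataLoader worker-process boundaries: a defaultdict is pickled together with its default_factory, and lambdas cannot be pickled, so the old version would fail with `num_workers > 0`. A quick standalone check of that behavior:

    import pickle
    from collections import defaultdict

    def empty_list():
        return []

    pickle.dumps(defaultdict(empty_list))  # works: named functions pickle by reference

    try:
        pickle.dumps(defaultdict(lambda: []))
    except (pickle.PicklingError, AttributeError) as e:
        print('lambda default_factory is not picklable:', e)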