Unverified Commit 4f834e2c authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Improve new dataset API (#433)

parent 6b058c6e
......@@ -82,7 +82,6 @@ class Transformations:
@staticmethod
def subtract_self_energies(iter_, self_energies=None):
iter_ = list(iter_)
intercept = 0.0
if isinstance(self_energies, utils.EnergyShifter):
shifter = self_energies
......@@ -185,9 +184,6 @@ class TransformableIterable:
def __iter__(self):
return iter(self.wrapped_iter)
def __next__(self):
return next(self.wrapped_iter)
def __getattr__(self, name):
transformation = getattr(Transformations, name)
......@@ -220,6 +216,14 @@ class TransformableIterable:
def load(path, additional_properties=()):
properties = PROPERTIES + additional_properties
# https://stackoverflow.com/a/39564774
class IterableAdapter:
def __init__(self, iterator_factory):
self.iterator_factory = iterator_factory
def __iter__(self):
return self.iterator_factory()
def h5_files(path):
"""yield file name of all h5 files in a path"""
if isdir(path):
......@@ -232,7 +236,7 @@ def load(path, additional_properties=()):
def molecules():
for f in h5_files(path):
anidata = anidataloader(f)
anidata_size = anidata.size()
anidata_size = anidata.group_size()
use_pbar = PKBAR_INSTALLED and verbose
if use_pbar:
pbar = pkbar.Pbar('=> loading {}, total molecules: {}'.format(f, anidata_size), anidata_size)
......@@ -252,7 +256,7 @@ def load(path, additional_properties=()):
ret[k] = m[k][i]
yield ret
return TransformableIterable(conformations())
return TransformableIterable(IterableAdapter(lambda: conformations()))
__all__ = ['load']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment