Unverified Commit 4f834e2c authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Improve new dataset API (#433)

parent 6b058c6e
...@@ -82,7 +82,6 @@ class Transformations: ...@@ -82,7 +82,6 @@ class Transformations:
@staticmethod @staticmethod
def subtract_self_energies(iter_, self_energies=None): def subtract_self_energies(iter_, self_energies=None):
iter_ = list(iter_)
intercept = 0.0 intercept = 0.0
if isinstance(self_energies, utils.EnergyShifter): if isinstance(self_energies, utils.EnergyShifter):
shifter = self_energies shifter = self_energies
...@@ -185,9 +184,6 @@ class TransformableIterable: ...@@ -185,9 +184,6 @@ class TransformableIterable:
def __iter__(self): def __iter__(self):
return iter(self.wrapped_iter) return iter(self.wrapped_iter)
def __next__(self):
return next(self.wrapped_iter)
def __getattr__(self, name): def __getattr__(self, name):
transformation = getattr(Transformations, name) transformation = getattr(Transformations, name)
...@@ -220,6 +216,14 @@ class TransformableIterable: ...@@ -220,6 +216,14 @@ class TransformableIterable:
def load(path, additional_properties=()): def load(path, additional_properties=()):
properties = PROPERTIES + additional_properties properties = PROPERTIES + additional_properties
# https://stackoverflow.com/a/39564774
class IterableAdapter:
def __init__(self, iterator_factory):
self.iterator_factory = iterator_factory
def __iter__(self):
return self.iterator_factory()
def h5_files(path): def h5_files(path):
"""yield file name of all h5 files in a path""" """yield file name of all h5 files in a path"""
if isdir(path): if isdir(path):
...@@ -232,7 +236,7 @@ def load(path, additional_properties=()): ...@@ -232,7 +236,7 @@ def load(path, additional_properties=()):
def molecules(): def molecules():
for f in h5_files(path): for f in h5_files(path):
anidata = anidataloader(f) anidata = anidataloader(f)
anidata_size = anidata.size() anidata_size = anidata.group_size()
use_pbar = PKBAR_INSTALLED and verbose use_pbar = PKBAR_INSTALLED and verbose
if use_pbar: if use_pbar:
pbar = pkbar.Pbar('=> loading {}, total molecules: {}'.format(f, anidata_size), anidata_size) pbar = pkbar.Pbar('=> loading {}, total molecules: {}'.format(f, anidata_size), anidata_size)
...@@ -252,7 +256,7 @@ def load(path, additional_properties=()): ...@@ -252,7 +256,7 @@ def load(path, additional_properties=()):
ret[k] = m[k][i] ret[k] = m[k][i]
yield ret yield ret
return TransformableIterable(conformations()) return TransformableIterable(IterableAdapter(lambda: conformations()))
__all__ = ['load'] __all__ = ['load']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment