Unverified Commit 7c597a39 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Make AEVCacheLoader a Dataset (#93)

parent 83107add
...@@ -219,7 +219,7 @@ class BatchedANIDataset(Dataset): ...@@ -219,7 +219,7 @@ class BatchedANIDataset(Dataset):
return len(self.batches) return len(self.batches)
class AEVCacheLoader: class AEVCacheLoader(Dataset):
"""Build a factory for AEV. """Build a factory for AEV.
The computation of AEV is the most time consuming part during training. The computation of AEV is the most time consuming part during training.
...@@ -233,6 +233,7 @@ class AEVCacheLoader: ...@@ -233,6 +233,7 @@ class AEVCacheLoader:
""" """
def __init__(self, disk_cache=None): def __init__(self, disk_cache=None):
super(AEVCacheLoader, self).__init__()
self.disk_cache = disk_cache self.disk_cache = disk_cache
# load dataset from disk cache # load dataset from disk cache
...@@ -241,12 +242,10 @@ class AEVCacheLoader: ...@@ -241,12 +242,10 @@ class AEVCacheLoader:
self.dataset = pickle.load(f) self.dataset = pickle.load(f)
def __getitem__(self, index): def __getitem__(self, index):
if index >= self.__len__(): _, output = self.dataset.batches[index]
raise IndexError()
aev_path = os.path.join(self.disk_cache, str(index)) aev_path = os.path.join(self.disk_cache, str(index))
with open(aev_path, 'rb') as f: with open(aev_path, 'rb') as f:
species_aevs = pickle.load(f) species_aevs = pickle.load(f)
_, output = self.dataset.batches[index]
return species_aevs, output return species_aevs, output
def __len__(self): def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment