# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import lmdb
import os
import pickle
import torch
import numpy as np
import collections
from functools import lru_cache
from . import data_utils
import logging

logger = logging.getLogger(__name__)


class LMDBDataset:
    """Read-only map-style dataset backed by a single-file LMDB database.

    Records are pickled objects stored under the keys enumerated at
    construction time. The LMDB environment used for reads is opened
    lazily on first ``__getitem__`` call (stored on ``self.env``), which
    keeps the instance safe to construct before forking worker processes.
    """

    # Max deserialized records kept per instance; mirrors the bound of the
    # previous ``lru_cache(maxsize=16)`` decorator.
    _CACHE_SIZE = 16

    def __init__(self, db_path):
        """Open *db_path*, enumerate all keys, then release the handle.

        Raises AssertionError if *db_path* is not an existing file.
        """
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(
            self.db_path
        )
        # Open a temporary environment only to enumerate keys, and close
        # it so the file handle is not leaked; the environment used for
        # reads is opened lazily in __getitem__.
        env = self.connect_db(self.db_path)
        try:
            with env.begin() as txn:
                self._keys = list(txn.cursor().iternext(values=False))
        finally:
            env.close()
        # Per-instance LRU cache. Decorating __getitem__ with @lru_cache
        # would key the cache on ``self`` and pin every instance alive for
        # the life of the process (flake8-bugbear B019), so the cache is
        # held on the instance instead.
        self._cache = collections.OrderedDict()

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open *lmdb_path* read-only and return the LMDB environment.

        When ``save_to_self`` is true the environment is also stored on
        ``self.env``. The environment is returned in both cases (callers
        of the old ``save_to_self=True`` path received and ignored
        ``None``, so this is backward compatible).
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if save_to_self:
            self.env = env
        return env

    def __len__(self):
        """Return the number of records in the database."""
        return len(self._keys)

    def __getitem__(self, idx):
        """Return the unpickled record at position *idx* (LRU-cached)."""
        # Fast path: serve from the per-instance LRU cache.
        if idx in self._cache:
            self._cache.move_to_end(idx)
            return self._cache[idx]
        # Lazily open the read environment on first access.
        if not hasattr(self, 'env'):
            self.connect_db(self.db_path, save_to_self=True)
        datapoint_pickled = self.env.begin().get(self._keys[idx])
        # SECURITY NOTE: pickle.loads executes arbitrary code during
        # deserialization — only open databases from trusted sources.
        data = pickle.loads(datapoint_pickled)
        # Insert and evict the least-recently-used entry when over budget.
        self._cache[idx] = data
        if len(self._cache) > self._CACHE_SIZE:
            self._cache.popitem(last=False)
        return data