Commit 73a29e02 authored by Jason Park's avatar Jason Park Committed by Soumith Chintala
Browse files

Update LSUN Dataset class (#452)

* Fix uninitialized instance variables

* Maintain consistency with other dataset classes

* Fix double assignment

* Fix initialization of self.classes
parent 00368603
......@@ -12,22 +12,23 @@ else:
class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None):
def __init__(self, root, transform=None, target_transform=None):
import lmdb
self.db_path = db_path
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False,
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False)
with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries']
cache_file = '_cache_' + db_path.replace('/', '_')
cache_file = '_cache_' + root.replace('/', '_')
if os.path.isfile(cache_file):
self.keys = pickle.load(open(cache_file, "rb"))
else:
with self.env.begin(write=False) as txn:
self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb"))
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
img, target = None, None
......@@ -60,7 +61,7 @@ class LSUN(data.Dataset):
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
Args:
db_path (string): Root directory for the database files.
root (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e,g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in an PIL image
......@@ -69,13 +70,16 @@ class LSUN(data.Dataset):
target and transforms it.
"""
def __init__(self, db_path, classes='train',
def __init__(self, root, classes='train',
transform=None, target_transform=None):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']
self.db_path = db_path
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if type(classes) == str and classes in dset_opts:
if classes == 'test':
classes = [classes]
......@@ -102,7 +106,7 @@ class LSUN(data.Dataset):
self.dbs = []
for c in self.classes:
self.dbs.append(LSUNClass(
db_path=db_path + '/' + c + '_lmdb',
root=root + '/' + c + '_lmdb',
transform=transform))
self.indices = []
......@@ -112,7 +116,6 @@ class LSUN(data.Dataset):
self.indices.append(count)
self.length = count
self.target_transform = target_transform
def __getitem__(self, index):
"""
......@@ -146,6 +149,7 @@ class LSUN(data.Dataset):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
fmt_str += ' Classes: {}\n'.format(self.classes)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment