You need to sign in or sign up before continuing.
Commit 73a29e02 authored by Jason Park's avatar Jason Park Committed by Soumith Chintala
Browse files

Update LSUN Dataset class (#452)

* Fix uninitialized instance variables

* Maintain consistency with other dataset classes

* Fix double assignment

* Fix initialization of self.classes
parent 00368603
...@@ -12,22 +12,23 @@ else: ...@@ -12,22 +12,23 @@ else:
class LSUNClass(data.Dataset): class LSUNClass(data.Dataset):
def __init__(self, db_path, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None):
import lmdb import lmdb
self.db_path = db_path self.root = os.path.expanduser(root)
self.env = lmdb.open(db_path, max_readers=1, readonly=True, lock=False, self.transform = transform
self.target_transform = target_transform
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False) readahead=False, meminit=False)
with self.env.begin(write=False) as txn: with self.env.begin(write=False) as txn:
self.length = txn.stat()['entries'] self.length = txn.stat()['entries']
cache_file = '_cache_' + db_path.replace('/', '_') cache_file = '_cache_' + root.replace('/', '_')
if os.path.isfile(cache_file): if os.path.isfile(cache_file):
self.keys = pickle.load(open(cache_file, "rb")) self.keys = pickle.load(open(cache_file, "rb"))
else: else:
with self.env.begin(write=False) as txn: with self.env.begin(write=False) as txn:
self.keys = [key for key, _ in txn.cursor()] self.keys = [key for key, _ in txn.cursor()]
pickle.dump(self.keys, open(cache_file, "wb")) pickle.dump(self.keys, open(cache_file, "wb"))
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
img, target = None, None img, target = None, None
...@@ -60,7 +61,7 @@ class LSUN(data.Dataset): ...@@ -60,7 +61,7 @@ class LSUN(data.Dataset):
`LSUN <http://lsun.cs.princeton.edu>`_ dataset. `LSUN <http://lsun.cs.princeton.edu>`_ dataset.
Args: Args:
db_path (string): Root directory for the database files. root (string): Root directory for the database files.
classes (string or list): One of {'train', 'val', 'test'} or a list of classes (string or list): One of {'train', 'val', 'test'} or a list of
categories to load. e,g. ['bedroom_train', 'church_train']. categories to load. e,g. ['bedroom_train', 'church_train'].
transform (callable, optional): A function/transform that takes in an PIL image transform (callable, optional): A function/transform that takes in an PIL image
...@@ -69,13 +70,16 @@ class LSUN(data.Dataset): ...@@ -69,13 +70,16 @@ class LSUN(data.Dataset):
target and transforms it. target and transforms it.
""" """
def __init__(self, db_path, classes='train', def __init__(self, root, classes='train',
transform=None, target_transform=None): transform=None, target_transform=None):
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower'] 'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test'] dset_opts = ['train', 'val', 'test']
self.db_path = db_path self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if type(classes) == str and classes in dset_opts: if type(classes) == str and classes in dset_opts:
if classes == 'test': if classes == 'test':
classes = [classes] classes = [classes]
...@@ -102,7 +106,7 @@ class LSUN(data.Dataset): ...@@ -102,7 +106,7 @@ class LSUN(data.Dataset):
self.dbs = [] self.dbs = []
for c in self.classes: for c in self.classes:
self.dbs.append(LSUNClass( self.dbs.append(LSUNClass(
db_path=db_path + '/' + c + '_lmdb', root=root + '/' + c + '_lmdb',
transform=transform)) transform=transform))
self.indices = [] self.indices = []
...@@ -112,7 +116,6 @@ class LSUN(data.Dataset): ...@@ -112,7 +116,6 @@ class LSUN(data.Dataset):
self.indices.append(count) self.indices.append(count)
self.length = count self.length = count
self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -146,6 +149,7 @@ class LSUN(data.Dataset): ...@@ -146,6 +149,7 @@ class LSUN(data.Dataset):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root) fmt_str += ' Root Location: {}\n'.format(self.root)
fmt_str += ' Classes: {}\n'.format(self.classes)
tmp = ' Transforms (if any): ' tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): ' tmp = ' Target Transforms (if any): '
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment