Commit 1b9304c0 authored by soumith's avatar soumith
Browse files

lsun test classes fix

parent d8cb7f0a
......@@ -70,7 +70,10 @@ class LSUN(data.Dataset):
dset_opts = ['train', 'val', 'test']
self.db_path = db_path
if type(classes) == str and classes in dset_opts:
classes = [c + '_' + classes for c in categories]
if classes == 'test':
classes = [classes]
else:
classes = [c + '_' + classes for c in categories]
if type(classes) == list:
for c in classes:
c_short = c.split('_')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment