Commit 2b81ad8c authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Always pass transform and target_transform to abstract dataset (#1126)

* fixed call to the VisionDataset constructor

* change call from keyword arguments to positional

* changed order of arguments

* removed transforms argument once again

* Fixed call to constructor of parent class

* fixed LSUN

* fixed Caltech256
parent 2cae9509
......@@ -26,17 +26,16 @@ class Caltech101(VisionDataset):
downloaded again.
"""
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
def __init__(self, root, target_type="category", transform=None,
target_transform=None, download=False):
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
transform=transform,
target_transform=target_transform)
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
......@@ -143,13 +142,11 @@ class Caltech256(VisionDataset):
downloaded again.
"""
def __init__(self, root,
transform=None, target_transform=None,
download=False):
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
def __init__(self, root, transform=None, target_transform=None, download=False):
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
transform=transform,
target_transform=target_transform)
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
......
......@@ -48,20 +48,16 @@ class CelebA(VisionDataset):
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
def __init__(self, root,
split="train",
target_type="attr",
transform=None, target_transform=None,
download=False):
def __init__(self, root, split="train", target_type="attr", transform=None,
target_transform=None, download=False):
import pandas
super(CelebA, self).__init__(root)
super(CelebA, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
......@@ -70,9 +66,6 @@ class CelebA(VisionDataset):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.transform = transform
self.target_transform = target_transform
if split.lower() == "train":
split = 0
elif split.lower() == "valid":
......
......@@ -52,13 +52,11 @@ class CIFAR10(VisionDataset):
'md5': '5ff9c542aee3614f3951f8cda6e48888',
}
def __init__(self, root, train=True,
transform=None, target_transform=None,
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(CIFAR10, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
super(CIFAR10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
......
......@@ -63,9 +63,8 @@ class Flickr8k(VisionDataset):
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
super(Flickr8k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
super(Flickr8k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
......@@ -115,9 +114,8 @@ class Flickr30k(VisionDataset):
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
super(Flickr30k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
super(Flickr30k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
......
......@@ -86,10 +86,10 @@ class DatasetFolder(VisionDataset):
targets (list): The class_index value for each image in the dataset
"""
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
......
......@@ -15,9 +15,8 @@ else:
class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None):
import lmdb
super(LSUNClass, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
super(LSUNClass, self).__init__(root, transform=transform,
target_transform=target_transform)
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False)
......@@ -68,11 +67,9 @@ class LSUN(VisionDataset):
target and transforms it.
"""
def __init__(self, root, classes='train',
transform=None, target_transform=None):
super(LSUN, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, classes='train', transform=None, target_transform=None):
super(LSUN, self).__init__(root, transform=transform,
target_transform=target_transform)
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
......
......@@ -57,10 +57,10 @@ class MNIST(VisionDataset):
warnings.warn("test_data has been renamed data")
return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
super(MNIST, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(MNIST, self).__init__(root, transform=transform,
target_transform=target_transform)
self.train = train # training set or test set
if download:
......
......@@ -28,12 +28,10 @@ class Omniglot(VisionDataset):
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
}
def __init__(self, root, background=True,
transform=None, target_transform=None,
def __init__(self, root, background=True, transform=None, target_transform=None,
download=False):
super(Omniglot, self).__init__(join(root, self.folder))
self.transform = transform
self.target_transform = target_transform
super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
target_transform=target_transform)
self.background = background
if download:
......
......@@ -65,8 +65,7 @@ class PhotoTour(VisionDataset):
matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False):
super(PhotoTour, self).__init__(root)
self.transform = transform
super(PhotoTour, self).__init__(root, transform=transform)
self.name = name
self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name))
......
......@@ -24,11 +24,9 @@ class SBU(VisionDataset):
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None,
download=True):
super(SBU, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, transform=None, target_transform=None, download=True):
super(SBU, self).__init__(root, transform=transform,
target_transform=target_transform)
if download:
self.download()
......
......@@ -24,11 +24,9 @@ class SEMEION(VisionDataset):
filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None,
download=True):
super(SEMEION, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, transform=None, target_transform=None, download=True):
super(SEMEION, self).__init__(root, transform=transform,
target_transform=target_transform)
if download:
self.download()
......
......@@ -46,15 +46,14 @@ class STL10(VisionDataset):
]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', folds=None,
transform=None, target_transform=None, download=False):
def __init__(self, root, split='train', folds=None, transform=None,
target_transform=None, download=False):
if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits),
))
super(STL10, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
super(STL10, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split # train/test/unlabeled set
self.folds = folds # one of the 10 pre-defined folds or the full dataset
......
......@@ -39,11 +39,10 @@ class SVHN(VisionDataset):
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
super(SVHN, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
def __init__(self, root, split='train', transform=None, target_transform=None,
download=False):
super(SVHN, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split # training set or test set or extra set
if self.split not in self.split_list:
......
......@@ -37,8 +37,10 @@ class USPS(VisionDataset):
],
}
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
super(USPS, self).__init__(root, transform=transform, target_transform=target_transform)
def __init__(self, root, train=True, transform=None, target_transform=None,
download=False):
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test'
url, filename, checksum = self.split_list[split]
full_path = os.path.join(self.root, filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment