Commit 2b81ad8c authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Always pass transform and target_transform to abstract dataset (#1126)

* fixed call to the VisionDataset constructor

* change call from keyword arguments to positional

* changed order of arguments

* removed transforms argument once again

* Fixed call to constructor of parent class

* fixed LSUN

* fixed Caltech256
parent 2cae9509
...@@ -26,17 +26,16 @@ class Caltech101(VisionDataset): ...@@ -26,17 +26,16 @@ class Caltech101(VisionDataset):
downloaded again. downloaded again.
""" """
def __init__(self, root, target_type="category", def __init__(self, root, target_type="category", transform=None,
transform=None, target_transform=None, target_transform=None, download=False):
download=False): super(Caltech101, self).__init__(os.path.join(root, 'caltech101'),
super(Caltech101, self).__init__(os.path.join(root, 'caltech101')) transform=transform,
target_transform=target_transform)
makedir_exist_ok(self.root) makedir_exist_ok(self.root)
if isinstance(target_type, list): if isinstance(target_type, list):
self.target_type = target_type self.target_type = target_type
else: else:
self.target_type = [target_type] self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download: if download:
self.download() self.download()
...@@ -143,13 +142,11 @@ class Caltech256(VisionDataset): ...@@ -143,13 +142,11 @@ class Caltech256(VisionDataset):
downloaded again. downloaded again.
""" """
def __init__(self, root, def __init__(self, root, transform=None, target_transform=None, download=False):
transform=None, target_transform=None, super(Caltech256, self).__init__(os.path.join(root, 'caltech256'),
download=False): transform=transform,
super(Caltech256, self).__init__(os.path.join(root, 'caltech256')) target_transform=target_transform)
makedir_exist_ok(self.root) makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
if download: if download:
self.download() self.download()
......
...@@ -48,20 +48,16 @@ class CelebA(VisionDataset): ...@@ -48,20 +48,16 @@ class CelebA(VisionDataset):
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
] ]
def __init__(self, root, def __init__(self, root, split="train", target_type="attr", transform=None,
split="train", target_transform=None, download=False):
target_type="attr",
transform=None, target_transform=None,
download=False):
import pandas import pandas
super(CelebA, self).__init__(root) super(CelebA, self).__init__(root, transform=transform,
target_transform=target_transform)
self.split = split self.split = split
if isinstance(target_type, list): if isinstance(target_type, list):
self.target_type = target_type self.target_type = target_type
else: else:
self.target_type = [target_type] self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download: if download:
self.download() self.download()
...@@ -70,9 +66,6 @@ class CelebA(VisionDataset): ...@@ -70,9 +66,6 @@ class CelebA(VisionDataset):
raise RuntimeError('Dataset not found or corrupted.' + raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it') ' You can use download=True to download it')
self.transform = transform
self.target_transform = target_transform
if split.lower() == "train": if split.lower() == "train":
split = 0 split = 0
elif split.lower() == "valid": elif split.lower() == "valid":
......
...@@ -52,13 +52,11 @@ class CIFAR10(VisionDataset): ...@@ -52,13 +52,11 @@ class CIFAR10(VisionDataset):
'md5': '5ff9c542aee3614f3951f8cda6e48888', 'md5': '5ff9c542aee3614f3951f8cda6e48888',
} }
def __init__(self, root, train=True, def __init__(self, root, train=True, transform=None, target_transform=None,
transform=None, target_transform=None,
download=False): download=False):
super(CIFAR10, self).__init__(root) super(CIFAR10, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.train = train # training set or test set self.train = train # training set or test set
......
...@@ -63,9 +63,8 @@ class Flickr8k(VisionDataset): ...@@ -63,9 +63,8 @@ class Flickr8k(VisionDataset):
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(self, root, ann_file, transform=None, target_transform=None):
super(Flickr8k, self).__init__(root) super(Flickr8k, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
...@@ -115,9 +114,8 @@ class Flickr30k(VisionDataset): ...@@ -115,9 +114,8 @@ class Flickr30k(VisionDataset):
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(self, root, ann_file, transform=None, target_transform=None):
super(Flickr30k, self).__init__(root) super(Flickr30k, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
......
...@@ -86,10 +86,10 @@ class DatasetFolder(VisionDataset): ...@@ -86,10 +86,10 @@ class DatasetFolder(VisionDataset):
targets (list): The class_index value for each image in the dataset targets (list): The class_index value for each image in the dataset
""" """
def __init__(self, root, loader, extensions=None, transform=None, target_transform=None, is_valid_file=None): def __init__(self, root, loader, extensions=None, transform=None,
super(DatasetFolder, self).__init__(root) target_transform=None, is_valid_file=None):
self.transform = transform super(DatasetFolder, self).__init__(root, transform=transform,
self.target_transform = target_transform target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root) classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0: if len(samples) == 0:
......
...@@ -15,9 +15,8 @@ else: ...@@ -15,9 +15,8 @@ else:
class LSUNClass(VisionDataset): class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None):
import lmdb import lmdb
super(LSUNClass, self).__init__(root) super(LSUNClass, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False, self.env = lmdb.open(root, max_readers=1, readonly=True, lock=False,
readahead=False, meminit=False) readahead=False, meminit=False)
...@@ -68,11 +67,9 @@ class LSUN(VisionDataset): ...@@ -68,11 +67,9 @@ class LSUN(VisionDataset):
target and transforms it. target and transforms it.
""" """
def __init__(self, root, classes='train', def __init__(self, root, classes='train', transform=None, target_transform=None):
transform=None, target_transform=None): super(LSUN, self).__init__(root, transform=transform,
super(LSUN, self).__init__(root) target_transform=target_transform)
self.transform = transform
self.target_transform = target_transform
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower'] 'living_room', 'restaurant', 'tower']
......
...@@ -57,10 +57,10 @@ class MNIST(VisionDataset): ...@@ -57,10 +57,10 @@ class MNIST(VisionDataset):
warnings.warn("test_data has been renamed data") warnings.warn("test_data has been renamed data")
return self.data return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None,
super(MNIST, self).__init__(root) download=False):
self.transform = transform super(MNIST, self).__init__(root, transform=transform,
self.target_transform = target_transform target_transform=target_transform)
self.train = train # training set or test set self.train = train # training set or test set
if download: if download:
......
...@@ -28,12 +28,10 @@ class Omniglot(VisionDataset): ...@@ -28,12 +28,10 @@ class Omniglot(VisionDataset):
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
} }
def __init__(self, root, background=True, def __init__(self, root, background=True, transform=None, target_transform=None,
transform=None, target_transform=None,
download=False): download=False):
super(Omniglot, self).__init__(join(root, self.folder)) super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.background = background self.background = background
if download: if download:
......
...@@ -65,8 +65,7 @@ class PhotoTour(VisionDataset): ...@@ -65,8 +65,7 @@ class PhotoTour(VisionDataset):
matches_files = 'm50_100000_100000_0.txt' matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False): def __init__(self, root, name, train=True, transform=None, download=False):
super(PhotoTour, self).__init__(root) super(PhotoTour, self).__init__(root, transform=transform)
self.transform = transform
self.name = name self.name = name
self.data_dir = os.path.join(self.root, name) self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name)) self.data_down = os.path.join(self.root, '{}.zip'.format(name))
......
...@@ -24,11 +24,9 @@ class SBU(VisionDataset): ...@@ -24,11 +24,9 @@ class SBU(VisionDataset):
filename = "SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285' md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None, download=True):
download=True): super(SBU, self).__init__(root, transform=transform,
super(SBU, self).__init__(root) target_transform=target_transform)
self.transform = transform
self.target_transform = target_transform
if download: if download:
self.download() self.download()
......
...@@ -24,11 +24,9 @@ class SEMEION(VisionDataset): ...@@ -24,11 +24,9 @@ class SEMEION(VisionDataset):
filename = "semeion.data" filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432' md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None, download=True):
download=True): super(SEMEION, self).__init__(root, transform=transform,
super(SEMEION, self).__init__(root) target_transform=target_transform)
self.transform = transform
self.target_transform = target_transform
if download: if download:
self.download() self.download()
......
...@@ -46,15 +46,14 @@ class STL10(VisionDataset): ...@@ -46,15 +46,14 @@ class STL10(VisionDataset):
] ]
splits = ('train', 'train+unlabeled', 'unlabeled', 'test') splits = ('train', 'train+unlabeled', 'unlabeled', 'test')
def __init__(self, root, split='train', folds=None, def __init__(self, root, split='train', folds=None, transform=None,
transform=None, target_transform=None, download=False): target_transform=None, download=False):
if split not in self.splits: if split not in self.splits:
raise ValueError('Split "{}" not found. Valid splits are: {}'.format( raise ValueError('Split "{}" not found. Valid splits are: {}'.format(
split, ', '.join(self.splits), split, ', '.join(self.splits),
)) ))
super(STL10, self).__init__(root) super(STL10, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.split = split # train/test/unlabeled set self.split = split # train/test/unlabeled set
self.folds = folds # one of the 10 pre-defined folds or the full dataset self.folds = folds # one of the 10 pre-defined folds or the full dataset
......
...@@ -39,11 +39,10 @@ class SVHN(VisionDataset): ...@@ -39,11 +39,10 @@ class SVHN(VisionDataset):
'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat", 'extra': ["http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
"extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]} "extra_32x32.mat", "a93ce644f1a588dc4d68dda5feec44a7"]}
def __init__(self, root, split='train', def __init__(self, root, split='train', transform=None, target_transform=None,
transform=None, target_transform=None, download=False): download=False):
super(SVHN, self).__init__(root) super(SVHN, self).__init__(root, transform=transform,
self.transform = transform target_transform=target_transform)
self.target_transform = target_transform
self.split = split # training set or test set or extra set self.split = split # training set or test set or extra set
if self.split not in self.split_list: if self.split not in self.split_list:
......
...@@ -37,8 +37,10 @@ class USPS(VisionDataset): ...@@ -37,8 +37,10 @@ class USPS(VisionDataset):
], ],
} }
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None,
super(USPS, self).__init__(root, transform=transform, target_transform=target_transform) download=False):
super(USPS, self).__init__(root, transform=transform,
target_transform=target_transform)
split = 'train' if train else 'test' split = 'train' if train else 'test'
url, filename, checksum = self.split_list[split] url, filename, checksum = self.split_list[split]
full_path = os.path.join(self.root, filename) full_path = os.path.join(self.root, filename)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment