Commit 6cabab3a authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Refactoring of the datasets (#749)

* introduced new super class for all vision datasets

* Removed root from repr if dataset has no root

* reverted some overly-ambitious autoformatting

* reverted some overly-ambitious autoformatting

* added split attribute to repr of STL10 dataset

* made Python2 friendly and more robust

* Fixed call of the superclass constructor

* moved transform and target_transform back to the base classes

* added check if transforms are present before printing to avoid setting them within the constructor

* added missing transforms and target_transforms to base classes

* fixed linter error
parent 50ea596e
...@@ -4,16 +4,17 @@ import os ...@@ -4,16 +4,17 @@ import os
import os.path import os.path
import numpy as np import numpy as np
import sys import sys
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
else: else:
import pickle import pickle
import torch.utils.data as data from .vision import VisionDataset
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
class CIFAR10(data.Dataset): class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset. """`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args: Args:
...@@ -54,9 +55,11 @@ class CIFAR10(data.Dataset): ...@@ -54,9 +55,11 @@ class CIFAR10(data.Dataset):
def __init__(self, root, train=True, def __init__(self, root, train=True,
transform=None, target_transform=None, transform=None, target_transform=None,
download=False): download=False):
self.root = os.path.expanduser(root)
super(CIFAR10, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.train = train # training set or test set self.train = train # training set or test set
if download: if download:
...@@ -153,17 +156,8 @@ class CIFAR10(data.Dataset): ...@@ -153,17 +156,8 @@ class CIFAR10(data.Dataset):
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar: with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root) tar.extractall(path=self.root)
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Split: {}".format("Train" if self.train is True else "Test")
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
class CIFAR100(CIFAR10): class CIFAR100(CIFAR10):
......
...@@ -2,11 +2,11 @@ import json ...@@ -2,11 +2,11 @@ import json
import os import os
from collections import namedtuple from collections import namedtuple
import torch.utils.data as data from .vision import VisionDataset
from PIL import Image from PIL import Image
class Cityscapes(data.Dataset): class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset. """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args: Args:
...@@ -93,12 +93,12 @@ class Cityscapes(data.Dataset): ...@@ -93,12 +93,12 @@ class Cityscapes(data.Dataset):
def __init__(self, root, split='train', mode='fine', target_type='instance', def __init__(self, root, split='train', mode='fine', target_type='instance',
transform=None, target_transform=None): transform=None, target_transform=None):
self.root = os.path.expanduser(root) super(Cityscapes, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse' self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split) self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, self.mode, split) self.targets_dir = os.path.join(self.root, self.mode, split)
self.transform = transform
self.target_transform = target_transform
self.target_type = target_type self.target_type = target_type
self.split = split self.split = split
self.images = [] self.images = []
...@@ -171,18 +171,9 @@ class Cityscapes(data.Dataset): ...@@ -171,18 +171,9 @@ class Cityscapes(data.Dataset):
def __len__(self): def __len__(self):
return len(self.images) return len(self.images)
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' lines = ["Split: {split}", "Mode: {mode}", "Type: {target_type}"]
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) return '\n'.join(lines).format(**self.__dict__)
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Mode: {}\n'.format(self.mode)
fmt_str += ' Type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def _load_json(self, path): def _load_json(self, path):
with open(path, 'r') as file: with open(path, 'r') as file:
......
import torch.utils.data as data from .vision import VisionDataset
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
class CocoCaptions(data.Dataset): class CocoCaptions(VisionDataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset. """`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
Args: Args:
...@@ -42,13 +42,14 @@ class CocoCaptions(data.Dataset): ...@@ -42,13 +42,14 @@ class CocoCaptions(data.Dataset):
u'A mountain view with a plume of smoke in the background'] u'A mountain view with a plume of smoke in the background']
""" """
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoCaptions, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = os.path.expanduser(root)
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys()) self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -79,7 +80,7 @@ class CocoCaptions(data.Dataset): ...@@ -79,7 +80,7 @@ class CocoCaptions(data.Dataset):
return len(self.ids) return len(self.ids)
class CocoDetection(data.Dataset): class CocoDetection(VisionDataset):
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset. """`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
Args: Args:
...@@ -92,12 +93,12 @@ class CocoDetection(data.Dataset): ...@@ -92,12 +93,12 @@ class CocoDetection(data.Dataset):
""" """
def __init__(self, root, annFile, transform=None, target_transform=None): def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO from pycocotools.coco import COCO
self.root = root
self.coco = COCO(annFile) self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys()) self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index): def __getitem__(self, index):
""" """
...@@ -125,13 +126,3 @@ class CocoDetection(data.Dataset): ...@@ -125,13 +126,3 @@ class CocoDetection(data.Dataset):
def __len__(self): def __len__(self):
return len(self.ids) return len(self.ids)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
import torch import torch
import torch.utils.data as data from .vision import VisionDataset
from .. import transforms from .. import transforms
class FakeData(data.Dataset): class FakeData(VisionDataset):
"""A fake dataset that returns randomly generated images and returns them as PIL images """A fake dataset that returns randomly generated images and returns them as PIL images
Args: Args:
...@@ -21,6 +21,9 @@ class FakeData(data.Dataset): ...@@ -21,6 +21,9 @@ class FakeData(data.Dataset):
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10, def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
transform=None, target_transform=None, random_offset=0): transform=None, target_transform=None, random_offset=0):
super(FakeData, self).__init__(None)
self.transform = transform
self.target_transform = target_transform
self.size = size self.size = size
self.num_classes = num_classes self.num_classes = num_classes
self.image_size = image_size self.image_size = image_size
...@@ -56,12 +59,3 @@ class FakeData(data.Dataset): ...@@ -56,12 +59,3 @@ class FakeData(data.Dataset):
def __len__(self): def __len__(self):
return self.size return self.size
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
...@@ -4,7 +4,7 @@ from six.moves import html_parser ...@@ -4,7 +4,7 @@ from six.moves import html_parser
import glob import glob
import os import os
import torch.utils.data as data from .vision import VisionDataset
class Flickr8kParser(html_parser.HTMLParser): class Flickr8kParser(html_parser.HTMLParser):
...@@ -50,7 +50,7 @@ class Flickr8kParser(html_parser.HTMLParser): ...@@ -50,7 +50,7 @@ class Flickr8kParser(html_parser.HTMLParser):
self.annotations[img_id].append(data.strip()) self.annotations[img_id].append(data.strip())
class Flickr8k(data.Dataset): class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset. """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
Args: Args:
...@@ -61,11 +61,12 @@ class Flickr8k(data.Dataset): ...@@ -61,11 +61,12 @@ class Flickr8k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root) super(Flickr8k, self).__init__(root)
self.ann_file = os.path.expanduser(ann_file)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
parser = Flickr8kParser(self.root) parser = Flickr8kParser(self.root)
...@@ -101,7 +102,7 @@ class Flickr8k(data.Dataset): ...@@ -101,7 +102,7 @@ class Flickr8k(data.Dataset):
return len(self.ids) return len(self.ids)
class Flickr30k(data.Dataset): class Flickr30k(VisionDataset):
"""`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset. """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.
Args: Args:
...@@ -112,11 +113,12 @@ class Flickr30k(data.Dataset): ...@@ -112,11 +113,12 @@ class Flickr30k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the target_transform (callable, optional): A function/transform that takes in the
target and transforms it. target and transforms it.
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root) super(Flickr30k, self).__init__(root)
self.ann_file = os.path.expanduser(ann_file)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict # Read annotations and store in a dict
self.annotations = defaultdict(list) self.annotations = defaultdict(list)
......
import torch.utils.data as data from .vision import VisionDataset
from PIL import Image from PIL import Image
...@@ -50,7 +50,7 @@ def make_dataset(dir, class_to_idx, extensions): ...@@ -50,7 +50,7 @@ def make_dataset(dir, class_to_idx, extensions):
return images return images
class DatasetFolder(data.Dataset): class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: :: """A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext root/class_x/xxx.ext
...@@ -79,13 +79,15 @@ class DatasetFolder(data.Dataset): ...@@ -79,13 +79,15 @@ class DatasetFolder(data.Dataset):
""" """
def __init__(self, root, loader, extensions, transform=None, target_transform=None): def __init__(self, root, loader, extensions, transform=None, target_transform=None):
super(DatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
classes, class_to_idx = self._find_classes(root) classes, class_to_idx = self._find_classes(root)
samples = make_dataset(root, class_to_idx, extensions) samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0: if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions))) "Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader self.loader = loader
self.extensions = extensions self.extensions = extensions
...@@ -94,9 +96,6 @@ class DatasetFolder(data.Dataset): ...@@ -94,9 +96,6 @@ class DatasetFolder(data.Dataset):
self.samples = samples self.samples = samples
self.targets = [s[1] for s in samples] self.targets = [s[1] for s in samples]
self.transform = transform
self.target_transform = target_transform
def _find_classes(self, dir): def _find_classes(self, dir):
""" """
Finds the class folders in a dataset. Finds the class folders in a dataset.
...@@ -139,16 +138,6 @@ class DatasetFolder(data.Dataset): ...@@ -139,16 +138,6 @@ class DatasetFolder(data.Dataset):
def __len__(self): def __len__(self):
return len(self.samples) return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp') IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp')
...@@ -201,6 +190,7 @@ class ImageFolder(DatasetFolder): ...@@ -201,6 +190,7 @@ class ImageFolder(DatasetFolder):
class_to_idx (dict): Dict with items (class_name, class_index). class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples imgs (list): List of (image path, class_index) tuples
""" """
def __init__(self, root, transform=None, target_transform=None, def __init__(self, root, transform=None, target_transform=None,
loader=default_loader): loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
......
import torch.utils.data as data from .vision import VisionDataset
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import six import six
import string import string
import sys import sys
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import cPickle as pickle import cPickle as pickle
else: else:
import pickle import pickle
class LSUNClass(data.Dataset): class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None): def __init__(self, root, transform=None, target_transform=None):
import lmdb import lmdb
self.root = os.path.expanduser(root) super(LSUNClass, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -52,11 +53,8 @@ class LSUNClass(data.Dataset): ...@@ -52,11 +53,8 @@ class LSUNClass(data.Dataset):
def __len__(self): def __len__(self):
return self.length return self.length
def __repr__(self):
return self.__class__.__name__ + ' (' + self.root + ')'
class LSUN(data.Dataset): class LSUN(VisionDataset):
""" """
`LSUN <http://lsun.cs.princeton.edu>`_ dataset. `LSUN <http://lsun.cs.princeton.edu>`_ dataset.
...@@ -72,13 +70,13 @@ class LSUN(data.Dataset): ...@@ -72,13 +70,13 @@ class LSUN(data.Dataset):
def __init__(self, root, classes='train', def __init__(self, root, classes='train',
transform=None, target_transform=None): transform=None, target_transform=None):
super(LSUN, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom', categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen', 'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower'] 'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test'] dset_opts = ['train', 'val', 'test']
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if type(classes) == str and classes in dset_opts: if type(classes) == str and classes in dset_opts:
if classes == 'test': if classes == 'test':
...@@ -91,15 +89,15 @@ class LSUN(data.Dataset): ...@@ -91,15 +89,15 @@ class LSUN(data.Dataset):
c_short.pop(len(c_short) - 1) c_short.pop(len(c_short) - 1)
c_short = '_'.join(c_short) c_short = '_'.join(c_short)
if c_short not in categories: if c_short not in categories:
raise(ValueError('Unknown LSUN class: ' + c_short + '.' raise (ValueError('Unknown LSUN class: ' + c_short + '.'
'Options are: ' + str(categories))) 'Options are: ' + str(categories)))
c_short = c.split('_') c_short = c.split('_')
c_short = c_short.pop(len(c_short) - 1) c_short = c_short.pop(len(c_short) - 1)
if c_short not in dset_opts: if c_short not in dset_opts:
raise(ValueError('Unknown postfix: ' + c_short + '.' raise (ValueError('Unknown postfix: ' + c_short + '.'
'Options are: ' + str(dset_opts))) 'Options are: ' + str(dset_opts)))
else: else:
raise(ValueError('Unknown option for classes')) raise (ValueError('Unknown option for classes'))
self.classes = classes self.classes = classes
# for each class, create an LSUNClassDataset # for each class, create an LSUNClassDataset
...@@ -145,13 +143,5 @@ class LSUN(data.Dataset): ...@@ -145,13 +143,5 @@ class LSUN(data.Dataset):
def __len__(self): def __len__(self):
return self.length return self.length
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Classes: {classes}".format(**self.__dict__)
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
fmt_str += ' Classes: {}\n'.format(self.classes)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
from __future__ import print_function from __future__ import print_function
from .vision import VisionDataset
import warnings import warnings
import torch.utils.data as data
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
...@@ -11,7 +11,7 @@ import codecs ...@@ -11,7 +11,7 @@ import codecs
from .utils import download_url, makedir_exist_ok from .utils import download_url, makedir_exist_ok
class MNIST(data.Dataset): class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset. """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args: Args:
...@@ -59,7 +59,7 @@ class MNIST(data.Dataset): ...@@ -59,7 +59,7 @@ class MNIST(data.Dataset):
return self.data return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, download=False): def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root) super(MNIST, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.train = train # training set or test set self.train = train # training set or test set
...@@ -115,8 +115,10 @@ class MNIST(data.Dataset): ...@@ -115,8 +115,10 @@ class MNIST(data.Dataset):
return {_class: i for i, _class in enumerate(self.classes)} return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self): def _check_exists(self):
return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \ return (os.path.exists(os.path.join(self.processed_folder,
os.path.exists(os.path.join(self.processed_folder, self.test_file)) self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
@staticmethod @staticmethod
def extract_gzip(gzip_path, remove_finished=False): def extract_gzip(gzip_path, remove_finished=False):
...@@ -161,17 +163,8 @@ class MNIST(data.Dataset): ...@@ -161,17 +163,8 @@ class MNIST(data.Dataset):
print('Done!') print('Done!')
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Split: {}".format("Train" if self.train is True else "Test")
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
class FashionMNIST(MNIST): class FashionMNIST(MNIST):
......
...@@ -2,11 +2,11 @@ from __future__ import print_function ...@@ -2,11 +2,11 @@ from __future__ import print_function
from PIL import Image from PIL import Image
from os.path import join from os.path import join
import os import os
import torch.utils.data as data from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files from .utils import download_url, check_integrity, list_dir, list_files
class Omniglot(data.Dataset): class Omniglot(VisionDataset):
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset. """`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (string): Root directory of dataset where directory
...@@ -31,10 +31,10 @@ class Omniglot(data.Dataset): ...@@ -31,10 +31,10 @@ class Omniglot(data.Dataset):
def __init__(self, root, background=True, def __init__(self, root, background=True,
transform=None, target_transform=None, transform=None, target_transform=None,
download=False): download=False):
self.root = join(os.path.expanduser(root), self.folder) super(Omniglot, self).__init__(join(os.path.expanduser(root), self.folder))
self.background = background
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.background = background
if download: if download:
self.download() self.download()
......
...@@ -3,12 +3,12 @@ import numpy as np ...@@ -3,12 +3,12 @@ import numpy as np
from PIL import Image from PIL import Image
import torch import torch
import torch.utils.data as data from .vision import VisionDataset
from .utils import download_url from .utils import download_url
class PhotoTour(data.Dataset): class PhotoTour(VisionDataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset. """`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
...@@ -65,14 +65,14 @@ class PhotoTour(data.Dataset): ...@@ -65,14 +65,14 @@ class PhotoTour(data.Dataset):
matches_files = 'm50_100000_100000_0.txt' matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False): def __init__(self, root, name, train=True, transform=None, download=False):
self.root = os.path.expanduser(root) super(PhotoTour, self).__init__(root)
self.transform = transform
self.name = name self.name = name
self.data_dir = os.path.join(self.root, name) self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name)) self.data_down = os.path.join(self.root, '{}.zip'.format(name))
self.data_file = os.path.join(self.root, '{}.pt'.format(name)) self.data_file = os.path.join(self.root, '{}.pt'.format(name))
self.train = train self.train = train
self.transform = transform
self.mean = self.mean[name] self.mean = self.mean[name]
self.std = self.std[name] self.std = self.std[name]
...@@ -151,20 +151,14 @@ class PhotoTour(data.Dataset): ...@@ -151,20 +151,14 @@ class PhotoTour(data.Dataset):
with open(self.data_file, 'wb') as f: with open(self.data_file, 'wb') as f:
torch.save(dataset, f) torch.save(dataset, f)
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Split: {}".format("Train" if self.train is True else "Test")
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def read_image_file(data_dir, image_ext, n): def read_image_file(data_dir, image_ext, n):
"""Return a Tensor containing the patches """Return a Tensor containing the patches
""" """
def PIL2array(_img): def PIL2array(_img):
"""Convert PIL image type to numpy 2D array """Convert PIL image type to numpy 2D array
""" """
...@@ -211,5 +205,6 @@ def read_matches_files(data_dir, matches_file): ...@@ -211,5 +205,6 @@ def read_matches_files(data_dir, matches_file):
with open(os.path.join(data_dir, matches_file), 'r') as f: with open(os.path.join(data_dir, matches_file), 'r') as f:
for line in f: for line in f:
line_split = line.split() line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])]) matches.append([int(line_split[0]), int(line_split[3]),
int(line_split[1] == line_split[4])])
return torch.LongTensor(matches) return torch.LongTensor(matches)
...@@ -3,10 +3,10 @@ from six.moves import zip ...@@ -3,10 +3,10 @@ from six.moves import zip
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
import os import os
import torch.utils.data as data from .vision import VisionDataset
class SBU(data.Dataset): class SBU(VisionDataset):
"""`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset. """`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
Args: Args:
...@@ -24,8 +24,9 @@ class SBU(data.Dataset): ...@@ -24,8 +24,9 @@ class SBU(data.Dataset):
filename = "SBUCaptionedPhotoDataset.tar.gz" filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285' md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None, download=True): def __init__(self, root, transform=None, target_transform=None,
self.root = os.path.expanduser(root) download=True):
super(SBU, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
......
...@@ -3,11 +3,11 @@ from PIL import Image ...@@ -3,11 +3,11 @@ from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
import torch.utils.data as data from .vision import VisionDataset
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
class SEMEION(data.Dataset): class SEMEION(VisionDataset):
"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset. """`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
Args: Args:
root (string): Root directory of dataset where directory root (string): Root directory of dataset where directory
...@@ -24,8 +24,9 @@ class SEMEION(data.Dataset): ...@@ -24,8 +24,9 @@ class SEMEION(data.Dataset):
filename = "semeion.data" filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432' md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, download=True): def __init__(self, root, transform=None, target_transform=None,
self.root = os.path.expanduser(root) download=True):
super(SEMEION, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
...@@ -84,13 +85,3 @@ class SEMEION(data.Dataset): ...@@ -84,13 +85,3 @@ class SEMEION(data.Dataset):
root = self.root root = self.root
download_url(self.url, root, self.filename, self.md5_checksum) download_url(self.url, root, self.filename, self.md5_checksum)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
...@@ -129,13 +129,5 @@ class STL10(CIFAR10): ...@@ -129,13 +129,5 @@ class STL10(CIFAR10):
return images, labels return images, labels
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Split: {split}".format(**self.__dict__)
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
from __future__ import print_function from __future__ import print_function
import torch.utils.data as data from .vision import VisionDataset
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from .utils import download_url, check_integrity from .utils import download_url, check_integrity
class SVHN(data.Dataset): class SVHN(VisionDataset):
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset. """`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset, Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
...@@ -41,7 +41,7 @@ class SVHN(data.Dataset): ...@@ -41,7 +41,7 @@ class SVHN(data.Dataset):
def __init__(self, root, split='train', def __init__(self, root, split='train',
transform=None, target_transform=None, download=False): transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root) super(SVHN, self).__init__(root)
self.transform = transform self.transform = transform
self.target_transform = target_transform self.target_transform = target_transform
self.split = split # training set or test set or extra set self.split = split # training set or test set or extra set
...@@ -116,13 +116,5 @@ class SVHN(data.Dataset): ...@@ -116,13 +116,5 @@ class SVHN(data.Dataset):
md5 = self.split_list[self.split][2] md5 = self.split_list[self.split][2]
download_url(self.url, self.root, self.filename, md5) download_url(self.url, self.root, self.filename, md5)
def __repr__(self): def extra_repr(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' return "Split: {split}".format(**self.__dict__)
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
import os
import torch
import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Abstract base class for all torchvision datasets.

    Normalizes and stores the dataset ``root`` directory and supplies a
    shared multi-line ``__repr__``; subclasses add dataset-specific lines
    by overriding :meth:`extra_repr`.
    """

    # Number of spaces each body line of the repr is indented by.
    _repr_indent = 4

    def __init__(self, root):
        # Expand "~" only for string-like roots; any other value
        # (e.g. None) is stored as given.
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root

    def __getitem__(self, index):
        # Subclasses must implement per-item access.
        raise NotImplementedError

    def __len__(self):
        # Subclasses must implement the dataset size.
        raise NotImplementedError

    def __repr__(self):
        """Return a multi-line summary: name, size, root, extras, transforms."""
        header = "Dataset " + type(self).__name__
        details = ["Number of datapoints: {}".format(len(self))]
        if self.root is not None:
            details.append("Root location: {}".format(self.root))
        details.extend(self.extra_repr().splitlines())
        # Only report transforms that exist on the instance and are set.
        for attr, label in (("transform", "Transforms: "),
                            ("target_transform", "Target transforms: ")):
            value = getattr(self, attr, None)
            if value is not None:
                details.extend(self._format_transform_repr(value, label))
        indent = " " * self._repr_indent
        return "\n".join([header] + [indent + line for line in details])

    def _format_transform_repr(self, transform, head):
        # Prefix the first repr line with ``head`` and align the rest
        # under it so multi-line transform reprs stay readable.
        lines = repr(transform).splitlines()
        pad = " " * len(head)
        return [head + lines[0]] + [pad + line for line in lines[1:]]

    def extra_repr(self):
        # Hook for subclasses to contribute extra repr lines
        # (e.g. "Split: train"). Default: nothing.
        return ""
...@@ -2,7 +2,8 @@ import os ...@@ -2,7 +2,8 @@ import os
import sys import sys
import tarfile import tarfile
import collections import collections
import torch.utils.data as data from .vision import VisionDataset
if sys.version_info[0] == 2: if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET import xml.etree.cElementTree as ET
else: else:
...@@ -51,7 +52,7 @@ DATASET_YEAR_DICT = { ...@@ -51,7 +52,7 @@ DATASET_YEAR_DICT = {
} }
class VOCSegmentation(data.Dataset): class VOCSegmentation(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args: Args:
...@@ -74,13 +75,13 @@ class VOCSegmentation(data.Dataset): ...@@ -74,13 +75,13 @@ class VOCSegmentation(data.Dataset):
download=False, download=False,
transform=None, transform=None,
target_transform=None): target_transform=None):
self.root = os.path.expanduser(root) super(VOCSegmentation, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.year = year self.year = year
self.url = DATASET_YEAR_DICT[year]['url'] self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename'] self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5'] self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir'] base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir) voc_root = os.path.join(self.root, base_dir)
...@@ -133,7 +134,7 @@ class VOCSegmentation(data.Dataset): ...@@ -133,7 +134,7 @@ class VOCSegmentation(data.Dataset):
return len(self.images) return len(self.images)
class VOCDetection(data.Dataset): class VOCDetection(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset. """`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
Args: Args:
...@@ -157,13 +158,13 @@ class VOCDetection(data.Dataset): ...@@ -157,13 +158,13 @@ class VOCDetection(data.Dataset):
download=False, download=False,
transform=None, transform=None,
target_transform=None): target_transform=None):
self.root = os.path.expanduser(root) super(VOCDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.year = year self.year = year
self.url = DATASET_YEAR_DICT[year]['url'] self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename'] self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5'] self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir'] base_dir = DATASET_YEAR_DICT[year]['base_dir']
...@@ -228,8 +229,8 @@ class VOCDetection(data.Dataset): ...@@ -228,8 +229,8 @@ class VOCDetection(data.Dataset):
def_dic[ind].append(v) def_dic[ind].append(v)
voc_dict = { voc_dict = {
node.tag: node.tag:
{ind: v[0] if len(v) == 1 else v {ind: v[0] if len(v) == 1 else v
for ind, v in def_dic.items()} for ind, v in def_dic.items()}
} }
if node.text: if node.text:
text = node.text.strip() text = node.text.strip()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment