Commit 6cabab3a authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

Refactoring of the datasets (#749)

* introduced new super class for all vision datasets

* Removed root from repr if dataset has no root

* reverted some overly-ambitious autoformatting

* reverted some overly-ambitious autoformatting

* added split attribute to repr of STL10 dataset

* made Python2 friendly and more robust

* Fixed call of the superclass constructor

* moved transform and target_transform back to the base classes

* added check if transforms are present before printing to avoid setting them within the constructor

* added missing transforms and target_transforms to base classes

* fixed linter error
parent 50ea596e
......@@ -4,16 +4,17 @@ import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import torch.utils.data as data
from .vision import VisionDataset
from .utils import download_url, check_integrity
class CIFAR10(data.Dataset):
class CIFAR10(VisionDataset):
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
Args:
......@@ -54,9 +55,11 @@ class CIFAR10(data.Dataset):
def __init__(self, root, train=True,
transform=None, target_transform=None,
download=False):
self.root = os.path.expanduser(root)
super(CIFAR10, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
if download:
......@@ -153,17 +156,8 @@ class CIFAR10(data.Dataset):
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the dataset-specific line shown in the base-class repr."""
    split_name = "Train" if self.train is True else "Test"
    return "Split: {}".format(split_name)
class CIFAR100(CIFAR10):
......
......@@ -2,11 +2,11 @@ import json
import os
from collections import namedtuple
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
class Cityscapes(data.Dataset):
class Cityscapes(VisionDataset):
"""`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
Args:
......@@ -93,12 +93,12 @@ class Cityscapes(data.Dataset):
def __init__(self, root, split='train', mode='fine', target_type='instance',
transform=None, target_transform=None):
self.root = os.path.expanduser(root)
super(Cityscapes, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, self.mode, split)
self.transform = transform
self.target_transform = target_transform
self.target_type = target_type
self.split = split
self.images = []
......@@ -171,18 +171,9 @@ class Cityscapes(data.Dataset):
def __len__(self):
return len(self.images)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Mode: {}\n'.format(self.mode)
fmt_str += ' Type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the Cityscapes-specific lines shown in the base-class repr."""
    template = "Split: {split}\nMode: {mode}\nType: {target_type}"
    return template.format(**self.__dict__)
def _load_json(self, path):
with open(path, 'r') as file:
......
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
import os
import os.path
class CocoCaptions(data.Dataset):
class CocoCaptions(VisionDataset):
"""`MS Coco Captions <http://mscoco.org/dataset/#captions-challenge2015>`_ Dataset.
Args:
......@@ -42,13 +42,14 @@ class CocoCaptions(data.Dataset):
u'A mountain view with a plume of smoke in the background']
"""
def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoCaptions, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO
self.root = os.path.expanduser(root)
self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
"""
......@@ -79,7 +80,7 @@ class CocoCaptions(data.Dataset):
return len(self.ids)
class CocoDetection(data.Dataset):
class CocoDetection(VisionDataset):
"""`MS Coco Detection <http://mscoco.org/dataset/#detections-challenge2016>`_ Dataset.
Args:
......@@ -92,12 +93,12 @@ class CocoDetection(data.Dataset):
"""
def __init__(self, root, annFile, transform=None, target_transform=None):
super(CocoDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
from pycocotools.coco import COCO
self.root = root
self.coco = COCO(annFile)
self.ids = list(self.coco.imgs.keys())
self.transform = transform
self.target_transform = target_transform
def __getitem__(self, index):
"""
......@@ -125,13 +126,3 @@ class CocoDetection(data.Dataset):
def __len__(self):
return len(self.ids)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
import torch
import torch.utils.data as data
from .vision import VisionDataset
from .. import transforms
class FakeData(data.Dataset):
class FakeData(VisionDataset):
"""A fake dataset that returns randomly generated images and returns them as PIL images
Args:
......@@ -21,6 +21,9 @@ class FakeData(data.Dataset):
def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10,
transform=None, target_transform=None, random_offset=0):
super(FakeData, self).__init__(None)
self.transform = transform
self.target_transform = target_transform
self.size = size
self.num_classes = num_classes
self.image_size = image_size
......@@ -56,12 +59,3 @@ class FakeData(data.Dataset):
def __len__(self):
return self.size
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
......@@ -4,7 +4,7 @@ from six.moves import html_parser
import glob
import os
import torch.utils.data as data
from .vision import VisionDataset
class Flickr8kParser(html_parser.HTMLParser):
......@@ -50,7 +50,7 @@ class Flickr8kParser(html_parser.HTMLParser):
self.annotations[img_id].append(data.strip())
class Flickr8k(data.Dataset):
class Flickr8k(VisionDataset):
"""`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.
Args:
......@@ -61,11 +61,12 @@ class Flickr8k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root)
self.ann_file = os.path.expanduser(ann_file)
super(Flickr8k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
parser = Flickr8kParser(self.root)
......@@ -101,7 +102,7 @@ class Flickr8k(data.Dataset):
return len(self.ids)
class Flickr30k(data.Dataset):
class Flickr30k(VisionDataset):
"""`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.
Args:
......@@ -112,11 +113,12 @@ class Flickr30k(data.Dataset):
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
self.root = os.path.expanduser(root)
self.ann_file = os.path.expanduser(ann_file)
super(Flickr30k, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.ann_file = os.path.expanduser(ann_file)
# Read annotations and store in a dict
self.annotations = defaultdict(list)
......
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
......@@ -50,7 +50,7 @@ def make_dataset(dir, class_to_idx, extensions):
return images
class DatasetFolder(data.Dataset):
class DatasetFolder(VisionDataset):
"""A generic data loader where the samples are arranged in this way: ::
root/class_x/xxx.ext
......@@ -79,13 +79,15 @@ class DatasetFolder(data.Dataset):
"""
def __init__(self, root, loader, extensions, transform=None, target_transform=None):
super(DatasetFolder, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
classes, class_to_idx = self._find_classes(root)
samples = make_dataset(root, class_to_idx, extensions)
if len(samples) == 0:
raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n"
"Supported extensions are: " + ",".join(extensions)))
self.root = root
self.loader = loader
self.extensions = extensions
......@@ -94,9 +96,6 @@ class DatasetFolder(data.Dataset):
self.samples = samples
self.targets = [s[1] for s in samples]
self.transform = transform
self.target_transform = target_transform
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
......@@ -139,16 +138,6 @@ class DatasetFolder(data.Dataset):
def __len__(self):
return len(self.samples)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', 'webp')
......@@ -201,6 +190,7 @@ class ImageFolder(DatasetFolder):
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader):
super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
......
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
import os
import os.path
import six
import string
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
class LSUNClass(data.Dataset):
class LSUNClass(VisionDataset):
def __init__(self, root, transform=None, target_transform=None):
import lmdb
self.root = os.path.expanduser(root)
super(LSUNClass, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
......@@ -52,11 +53,8 @@ class LSUNClass(data.Dataset):
def __len__(self):
return self.length
def __repr__(self):
return self.__class__.__name__ + ' (' + self.root + ')'
class LSUN(data.Dataset):
class LSUN(VisionDataset):
"""
`LSUN <http://lsun.cs.princeton.edu>`_ dataset.
......@@ -72,13 +70,13 @@ class LSUN(data.Dataset):
def __init__(self, root, classes='train',
transform=None, target_transform=None):
super(LSUN, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
categories = ['bedroom', 'bridge', 'church_outdoor', 'classroom',
'conference_room', 'dining_room', 'kitchen',
'living_room', 'restaurant', 'tower']
dset_opts = ['train', 'val', 'test']
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if type(classes) == str and classes in dset_opts:
if classes == 'test':
......@@ -91,15 +89,15 @@ class LSUN(data.Dataset):
c_short.pop(len(c_short) - 1)
c_short = '_'.join(c_short)
if c_short not in categories:
raise(ValueError('Unknown LSUN class: ' + c_short + '.'
raise (ValueError('Unknown LSUN class: ' + c_short + '.'
'Options are: ' + str(categories)))
c_short = c.split('_')
c_short = c_short.pop(len(c_short) - 1)
if c_short not in dset_opts:
raise(ValueError('Unknown postfix: ' + c_short + '.'
raise (ValueError('Unknown postfix: ' + c_short + '.'
'Options are: ' + str(dset_opts)))
else:
raise(ValueError('Unknown option for classes'))
raise (ValueError('Unknown option for classes'))
self.classes = classes
# for each class, create an LSUNClassDataset
......@@ -145,13 +143,5 @@ class LSUN(data.Dataset):
def __len__(self):
return self.length
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
fmt_str += ' Classes: {}\n'.format(self.classes)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the LSUN-specific line shown in the base-class repr."""
    return "Classes: {}".format(self.classes)
from __future__ import print_function
from .vision import VisionDataset
import warnings
import torch.utils.data as data
from PIL import Image
import os
import os.path
......@@ -11,7 +11,7 @@ import codecs
from .utils import download_url, makedir_exist_ok
class MNIST(data.Dataset):
class MNIST(VisionDataset):
"""`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
Args:
......@@ -59,7 +59,7 @@ class MNIST(data.Dataset):
return self.data
def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
super(MNIST, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
......@@ -115,8 +115,10 @@ class MNIST(data.Dataset):
return {_class: i for i, _class in enumerate(self.classes)}
def _check_exists(self):
return os.path.exists(os.path.join(self.processed_folder, self.training_file)) and \
os.path.exists(os.path.join(self.processed_folder, self.test_file))
return (os.path.exists(os.path.join(self.processed_folder,
self.training_file)) and
os.path.exists(os.path.join(self.processed_folder,
self.test_file)))
@staticmethod
def extract_gzip(gzip_path, remove_finished=False):
......@@ -161,17 +163,8 @@ class MNIST(data.Dataset):
print('Done!')
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the dataset-specific line shown in the base-class repr."""
    if self.train is True:
        split_name = "Train"
    else:
        split_name = "Test"
    return "Split: {}".format(split_name)
class FashionMNIST(MNIST):
......
......@@ -2,11 +2,11 @@ from __future__ import print_function
from PIL import Image
from os.path import join
import os
import torch.utils.data as data
from .vision import VisionDataset
from .utils import download_url, check_integrity, list_dir, list_files
class Omniglot(data.Dataset):
class Omniglot(VisionDataset):
"""`Omniglot <https://github.com/brendenlake/omniglot>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
......@@ -31,10 +31,10 @@ class Omniglot(data.Dataset):
def __init__(self, root, background=True,
transform=None, target_transform=None,
download=False):
self.root = join(os.path.expanduser(root), self.folder)
self.background = background
super(Omniglot, self).__init__(join(os.path.expanduser(root), self.folder))
self.transform = transform
self.target_transform = target_transform
self.background = background
if download:
self.download()
......
......@@ -3,12 +3,12 @@ import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
from .vision import VisionDataset
from .utils import download_url
class PhotoTour(data.Dataset):
class PhotoTour(VisionDataset):
"""`Learning Local Image Descriptors Data <http://phototour.cs.washington.edu/patches/default.htm>`_ Dataset.
......@@ -65,14 +65,14 @@ class PhotoTour(data.Dataset):
matches_files = 'm50_100000_100000_0.txt'
def __init__(self, root, name, train=True, transform=None, download=False):
self.root = os.path.expanduser(root)
super(PhotoTour, self).__init__(root)
self.transform = transform
self.name = name
self.data_dir = os.path.join(self.root, name)
self.data_down = os.path.join(self.root, '{}.zip'.format(name))
self.data_file = os.path.join(self.root, '{}.pt'.format(name))
self.train = train
self.transform = transform
self.mean = self.mean[name]
self.std = self.std[name]
......@@ -151,20 +151,14 @@ class PhotoTour(data.Dataset):
with open(self.data_file, 'wb') as f:
torch.save(dataset, f)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
tmp = 'train' if self.train is True else 'test'
fmt_str += ' Split: {}\n'.format(tmp)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the dataset-specific line shown in the base-class repr."""
    return "Split: {}".format("Train" if self.train is True else "Test")
def read_image_file(data_dir, image_ext, n):
"""Return a Tensor containing the patches
"""
def PIL2array(_img):
"""Convert PIL image type to numpy 2D array
"""
......@@ -211,5 +205,6 @@ def read_matches_files(data_dir, matches_file):
with open(os.path.join(data_dir, matches_file), 'r') as f:
for line in f:
line_split = line.split()
matches.append([int(line_split[0]), int(line_split[3]), int(line_split[1] == line_split[4])])
matches.append([int(line_split[0]), int(line_split[3]),
int(line_split[1] == line_split[4])])
return torch.LongTensor(matches)
......@@ -3,10 +3,10 @@ from six.moves import zip
from .utils import download_url, check_integrity
import os
import torch.utils.data as data
from .vision import VisionDataset
class SBU(data.Dataset):
class SBU(VisionDataset):
"""`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
Args:
......@@ -24,8 +24,9 @@ class SBU(data.Dataset):
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None, download=True):
self.root = os.path.expanduser(root)
def __init__(self, root, transform=None, target_transform=None,
download=True):
super(SBU, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
......
......@@ -3,11 +3,11 @@ from PIL import Image
import os
import os.path
import numpy as np
import torch.utils.data as data
from .vision import VisionDataset
from .utils import download_url, check_integrity
class SEMEION(data.Dataset):
class SEMEION(VisionDataset):
"""`SEMEION <http://archive.ics.uci.edu/ml/datasets/semeion+handwritten+digit>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
......@@ -24,8 +24,9 @@ class SEMEION(data.Dataset):
filename = "semeion.data"
md5_checksum = 'cb545d371d2ce14ec121470795a77432'
def __init__(self, root, transform=None, target_transform=None, download=True):
self.root = os.path.expanduser(root)
def __init__(self, root, transform=None, target_transform=None,
download=True):
super(SEMEION, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
......@@ -84,13 +85,3 @@ class SEMEION(data.Dataset):
root = self.root
download_url(self.url, root, self.filename, self.md5_checksum)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
......@@ -129,13 +129,5 @@ class STL10(CIFAR10):
return images, labels
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the STL10-specific line shown in the base-class repr."""
    return "Split: {}".format(self.split)
from __future__ import print_function
import torch.utils.data as data
from .vision import VisionDataset
from PIL import Image
import os
import os.path
......@@ -7,7 +7,7 @@ import numpy as np
from .utils import download_url, check_integrity
class SVHN(data.Dataset):
class SVHN(VisionDataset):
"""`SVHN <http://ufldl.stanford.edu/housenumbers/>`_ Dataset.
Note: The SVHN dataset assigns the label `10` to the digit `0`. However, in this Dataset,
we assign the label `0` to the digit `0` to be compatible with PyTorch loss functions which
......@@ -41,7 +41,7 @@ class SVHN(data.Dataset):
def __init__(self, root, split='train',
transform=None, target_transform=None, download=False):
self.root = os.path.expanduser(root)
super(SVHN, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.split = split # training set or test set or extra set
......@@ -116,13 +116,5 @@ class SVHN(data.Dataset):
md5 = self.split_list[self.split][2]
download_url(self.url, self.root, self.filename, md5)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
def extra_repr(self):
    """Return the SVHN-specific line shown in the base-class repr."""
    return "Split: {}".format(self.split)
import os
import torch
import torch.utils.data as data
class VisionDataset(data.Dataset):
    """Base class for all torchvision datasets.

    Centralizes handling of ``root`` and provides a uniform multi-line
    ``__repr__``.  Subclasses must implement ``__getitem__`` and
    ``__len__``, and may override ``extra_repr`` to append
    dataset-specific lines (e.g. split/mode) to the repr.
    """

    # Number of spaces used to indent the body lines emitted by __repr__.
    _repr_indent = 4

    def __init__(self, root):
        # Expand "~" only when root is a string; root may also be None
        # (datasets without an on-disk root, e.g. FakeData pass None).
        # torch._six.string_classes keeps this Python-2 friendly
        # (str and unicode) — NOTE(review): torch._six was removed in
        # later torch releases; confirm against the pinned torch version.
        if isinstance(root, torch._six.string_classes):
            root = os.path.expanduser(root)
        self.root = root

    def __getitem__(self, index):
        # Subclasses must return the sample (and target) at `index`.
        raise NotImplementedError

    def __len__(self):
        # Subclasses must return the number of datapoints.
        raise NotImplementedError

    def __repr__(self):
        """Build the repr: class name, size, root (if any), extras, transforms."""
        head = "Dataset " + self.__class__.__name__
        body = ["Number of datapoints: {}".format(self.__len__())]
        # Datasets constructed with root=None omit the root line entirely.
        if self.root is not None:
            body.append("Root location: {}".format(self.root))
        body += self.extra_repr().splitlines()
        # hasattr guard: transforms are set by subclasses, not by this base
        # class, so they may be absent as well as None.
        if hasattr(self, 'transform') and self.transform is not None:
            body += self._format_transform_repr(self.transform,
                                                "Transforms: ")
        if hasattr(self, 'target_transform') and self.target_transform is not None:
            body += self._format_transform_repr(self.target_transform,
                                                "Target transforms: ")
        lines = [head] + [" " * self._repr_indent + line for line in body]
        return '\n'.join(lines)

    def _format_transform_repr(self, transform, head):
        # Align continuation lines of a (possibly multi-line) transform
        # repr under the end of `head`.
        lines = transform.__repr__().splitlines()
        return (["{}{}".format(head, lines[0])] +
                ["{}{}".format(" " * len(head), line) for line in lines[1:]])

    def extra_repr(self):
        """Extra repr lines for subclasses to override; default is none."""
        return ""
......@@ -2,7 +2,8 @@ import os
import sys
import tarfile
import collections
import torch.utils.data as data
from .vision import VisionDataset
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
......@@ -51,7 +52,7 @@ DATASET_YEAR_DICT = {
}
class VOCSegmentation(data.Dataset):
class VOCSegmentation(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
......@@ -74,13 +75,13 @@ class VOCSegmentation(data.Dataset):
download=False,
transform=None,
target_transform=None):
self.root = os.path.expanduser(root)
super(VOCSegmentation, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
......@@ -133,7 +134,7 @@ class VOCSegmentation(data.Dataset):
return len(self.images)
class VOCDetection(data.Dataset):
class VOCDetection(VisionDataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Detection Dataset.
Args:
......@@ -157,13 +158,13 @@ class VOCDetection(data.Dataset):
download=False,
transform=None,
target_transform=None):
self.root = os.path.expanduser(root)
super(VOCDetection, self).__init__(root)
self.transform = transform
self.target_transform = target_transform
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.target_transform = target_transform
self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment