Commit bbaa1b0d authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

added support for VisionDataset (#838)

parent 8759f303
......@@ -2,19 +2,12 @@ from __future__ import print_function
from PIL import Image
import os
import os.path
import numpy as np
import sys
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
import collections
import torch.utils.data as data
from .utils import download_url, check_integrity, makedir_exist_ok
from .vision import VisionDataset
from .utils import download_url, makedir_exist_ok
class Caltech101(data.Dataset):
class Caltech101(VisionDataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
Args:
......@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech101")
super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
......@@ -138,19 +131,11 @@ class Caltech101(data.Dataset):
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
    """Return a human-readable, multi-line summary of the dataset."""
    def labelled(label, value):
        # Continuation lines of a multi-line repr are aligned under the label.
        return label + repr(value).replace('\n', '\n' + ' ' * len(label))

    pieces = [
        'Dataset ' + self.__class__.__name__ + '\n',
        ' Number of datapoints: {}\n'.format(len(self)),
        ' Target type: {}\n'.format(self.target_type),
        ' Root Location: {}\n'.format(self.root),
        labelled(' Transforms (if any): ', self.transform) + '\n',
        labelled(' Target Transforms (if any): ', self.target_transform),
    ]
    return ''.join(pieces)
def extra_repr(self):
    """Extra detail line for the dataset repr: the configured target type."""
    # Pull the value out of the instance dict, mirroring **self.__dict__
    # template expansion (a missing attribute raises KeyError either way).
    return "Target type: {}".format(self.__dict__["target_type"])
class Caltech256(data.Dataset):
class Caltech256(VisionDataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
Args:
......@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
def __init__(self, root,
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech256")
super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
......@@ -233,13 +218,3 @@ class Caltech256(data.Dataset):
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
    """Return a human-readable, multi-line summary of the dataset."""
    def labelled(label, value):
        # Align continuation lines of a multi-line repr under the label.
        return label + repr(value).replace('\n', '\n' + ' ' * len(label))

    summary = 'Dataset ' + self.__class__.__name__ + '\n'
    summary += ' Number of datapoints: {}\n'.format(len(self))
    summary += ' Root Location: {}\n'.format(self.root)
    summary += labelled(' Transforms (if any): ', self.transform) + '\n'
    summary += labelled(' Target Transforms (if any): ', self.target_transform)
    return summary
import torch
import torch.utils.data as data
import os
import PIL
from .vision import VisionDataset
from .utils import download_file_from_google_drive, check_integrity
class CelebA(data.Dataset):
class CelebA(VisionDataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
......@@ -53,7 +53,7 @@ class CelebA(data.Dataset):
transform=None, target_transform=None,
download=False):
import pandas
self.root = os.path.expanduser(root)
super(CelebA, self).__init__(root)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
......@@ -158,14 +158,6 @@ class CelebA(data.Dataset):
def __len__(self):
    """Number of samples — one per entry of ``self.attr``."""
    num_samples = len(self.attr)
    return num_samples
def __repr__(self):
    """Return a human-readable, multi-line summary of the dataset."""
    def labelled(label, value):
        # Align continuation lines of a multi-line repr under the label.
        return label + repr(value).replace('\n', '\n' + ' ' * len(label))

    pieces = [
        'Dataset ' + self.__class__.__name__ + '\n',
        ' Number of datapoints: {}\n'.format(len(self)),
        ' Target type: {}\n'.format(self.target_type),
        ' Split: {}\n'.format(self.split),
        ' Root Location: {}\n'.format(self.root),
        labelled(' Transforms (if any): ', self.transform) + '\n',
        labelled(' Target Transforms (if any): ', self.target_transform),
    ]
    return ''.join(pieces)
def extra_repr(self):
    """Extra detail lines for the dataset repr: target type and split."""
    # Single template instead of joining a list of fragments; the
    # **self.__dict__ expansion is kept so lookup semantics are unchanged.
    template = "Target type: {target_type}\nSplit: {split}"
    return template.format(**self.__dict__)
......@@ -132,25 +132,8 @@ class ImageNet(ImageFolder):
def split_folder(self):
    """Path of the directory that holds the images of the selected split."""
    folder = os.path.join(self.root, self.split)
    return folder
def __repr__(self):
    """Multi-line summary: class name followed by indented detail lines."""
    details = ["Number of datapoints: {}".format(len(self))]
    if self.root is not None:
        details.append("Root location: {}".format(self.root))
    details.append("Split: {}".format(self.split))
    # getattr with a None default mirrors the hasattr-and-not-None check.
    transform = getattr(self, 'transform', None)
    if transform is not None:
        details.extend(self._format_transform_repr(transform,
                                                   "Transforms: "))
    target_transform = getattr(self, 'target_transform', None)
    if target_transform is not None:
        details.extend(self._format_transform_repr(target_transform,
                                                   "Target transforms: "))
    out = ["Dataset " + self.__class__.__name__]
    out.extend("    " + line for line in details)
    return '\n'.join(out)
def _format_transform_repr(self, transform, head):
    """Render ``repr(transform)`` as a list of lines: the first prefixed
    with *head*, every following line indented to align under it."""
    pad = " " * len(head)
    rendered = repr(transform).splitlines()
    out = [head + rendered[0]]
    for extra in rendered[1:]:
        out.append(pad + extra)
    return out
def extra_repr(self):
    """Extra detail line for the dataset repr: the selected split."""
    # Instance-dict lookup keeps the **self.__dict__ expansion semantics.
    return "Split: {}".format(self.__dict__["split"])
def extract_tar(src, dest=None, gzip=None, delete=False):
......
import os
import torch.utils.data as data
from .vision import VisionDataset
import numpy as np
......@@ -8,7 +8,7 @@ from .utils import download_url
from .voc import download_extract
class SBDataset(data.Dataset):
class SBDataset(VisionDataset):
"""`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
......@@ -62,10 +62,11 @@ class SBDataset(data.Dataset):
raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
"pip install scipy")
super(SBDataset, self).__init__(root)
if mode not in ("segmentation", "boundaries"):
raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
self.root = os.path.expanduser(root)
self.xy_transform = xy_transform
self.image_set = image_set
self.mode = mode
......@@ -121,3 +122,7 @@ class SBDataset(data.Dataset):
def __len__(self):
    """Number of samples — one per entry of ``self.images``."""
    count = len(self.images)
    return count
def extra_repr(self):
    """Extra detail lines for the dataset repr: image set and mode."""
    # Single template instead of joining a list of fragments; the
    # **self.__dict__ expansion is kept so lookup semantics are unchanged.
    return "Image set: {image_set}\nMode: {mode}".format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment