Commit bbaa1b0d authored by Philip Meier's avatar Philip Meier Committed by Francisco Massa
Browse files

added support for VisionDataset (#838)

parent 8759f303
@@ -2,19 +2,12 @@ from __future__ import print_function
 from PIL import Image
 import os
 import os.path
-import numpy as np
-import sys
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
-import collections
-import torch.utils.data as data
+from .vision import VisionDataset
-from .utils import download_url, check_integrity, makedir_exist_ok
+from .utils import download_url, makedir_exist_ok
 
-class Caltech101(data.Dataset):
+class Caltech101(VisionDataset):
     """`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
 
     Args:
@@ -36,7 +29,7 @@ class Caltech101(data.Dataset):
     def __init__(self, root, target_type="category",
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech101")
+        super(Caltech101, self).__init__(os.path.join(root, 'caltech101'))
         makedir_exist_ok(self.root)
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -138,19 +131,11 @@ class Caltech101(data.Dataset):
         with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
             tar.extractall(path=self.root)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        return "Target type: {target_type}".format(**self.__dict__)
 
-class Caltech256(data.Dataset):
+class Caltech256(VisionDataset):
     """`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
 
     Args:
@@ -168,7 +153,7 @@ class Caltech256(data.Dataset):
     def __init__(self, root,
                  transform=None, target_transform=None,
                  download=False):
-        self.root = os.path.join(os.path.expanduser(root), "caltech256")
+        super(Caltech256, self).__init__(os.path.join(root, 'caltech256'))
         makedir_exist_ok(self.root)
         self.transform = transform
         self.target_transform = target_transform
@@ -233,13 +218,3 @@ class Caltech256(data.Dataset):
         # extract file
         with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
             tar.extractall(path=self.root)
-
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
 import torch
-import torch.utils.data as data
 import os
 import PIL
+from .vision import VisionDataset
 from .utils import download_file_from_google_drive, check_integrity
 
-class CelebA(data.Dataset):
+class CelebA(VisionDataset):
     """`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
 
     Args:
@@ -53,7 +53,7 @@ class CelebA(data.Dataset):
                  transform=None, target_transform=None,
                  download=False):
         import pandas
-        self.root = os.path.expanduser(root)
+        super(CelebA, self).__init__(root)
         self.split = split
         if isinstance(target_type, list):
             self.target_type = target_type
@@ -158,14 +158,6 @@ class CelebA(data.Dataset):
     def __len__(self):
         return len(self.attr)
 
-    def __repr__(self):
-        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
-        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
-        fmt_str += '    Target type: {}\n'.format(self.target_type)
-        fmt_str += '    Split: {}\n'.format(self.split)
-        fmt_str += '    Root Location: {}\n'.format(self.root)
-        tmp = '    Transforms (if any): '
-        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        tmp = '    Target Transforms (if any): '
-        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
-        return fmt_str
+    def extra_repr(self):
+        lines = ["Target type: {target_type}", "Split: {split}"]
+        return '\n'.join(lines).format(**self.__dict__)
@@ -132,25 +132,8 @@ class ImageNet(ImageFolder):
     def split_folder(self):
         return os.path.join(self.root, self.split)
 
-    def __repr__(self):
-        head = "Dataset " + self.__class__.__name__
-        body = ["Number of datapoints: {}".format(self.__len__())]
-        if self.root is not None:
-            body.append("Root location: {}".format(self.root))
-        body += ["Split: {}".format(self.split)]
-        if hasattr(self, 'transform') and self.transform is not None:
-            body += self._format_transform_repr(self.transform,
-                                                "Transforms: ")
-        if hasattr(self, 'target_transform') and self.target_transform is not None:
-            body += self._format_transform_repr(self.target_transform,
-                                                "Target transforms: ")
-        lines = [head] + [" " * 4 + line for line in body]
-        return '\n'.join(lines)
-
-    def _format_transform_repr(self, transform, head):
-        lines = transform.__repr__().splitlines()
-        return (["{}{}".format(head, lines[0])] +
-                ["{}{}".format(" " * len(head), line) for line in lines[1:]])
+    def extra_repr(self):
+        return "Split: {split}".format(**self.__dict__)
 
 def extract_tar(src, dest=None, gzip=None, delete=False):
...
 import os
-import torch.utils.data as data
+from .vision import VisionDataset
 import numpy as np
@@ -8,7 +8,7 @@ from .utils import download_url
 from .voc import download_extract
 
-class SBDataset(data.Dataset):
+class SBDataset(VisionDataset):
     """`Semantic Boundaries Dataset <http://home.bharathh.info/pubs/codes/SBD/download.html>`_
 
     The SBD currently contains annotations from 11355 images taken from the PASCAL VOC 2011 dataset.
@@ -62,10 +62,11 @@ class SBDataset(data.Dataset):
             raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: "
                                "pip install scipy")
 
+        super(SBDataset, self).__init__(root)
+
         if mode not in ("segmentation", "boundaries"):
             raise ValueError("Argument mode should be 'segmentation' or 'boundaries'")
 
-        self.root = os.path.expanduser(root)
         self.xy_transform = xy_transform
         self.image_set = image_set
         self.mode = mode
@@ -121,3 +122,7 @@ class SBDataset(data.Dataset):
     def __len__(self):
         return len(self.images)
+
+    def extra_repr(self):
+        lines = ["Image set: {image_set}", "Mode: {mode}"]
+        return '\n'.join(lines).format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment