Commit 9d9f48a3 authored by Bryan He, committed by Francisco Massa

Add Caltech101, Caltech256, and CelebA (#775)

* Add Caltech101 and Caltech256

* Add information about default for target_type

* Fix docs

* Add function to download from Google Drive

* Add CelebA dataset

* Only import pandas when needed

* Addressing comments

* Remove trailing whitespace

* Replace torch.LongTensor with torch.as_tensor
parent ab86a3a1
torchvision/datasets/__init__.py
@@ -13,6 +13,8 @@ from .sbu import SBU
 from .flickr import Flickr8k, Flickr30k
 from .voc import VOCSegmentation, VOCDetection
 from .cityscapes import Cityscapes
+from .caltech import Caltech101, Caltech256
+from .celeba import CelebA
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -20,4 +22,5 @@ __all__ = ('LSUN', 'LSUNClass',
            'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
            'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
            'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
-           'VOCSegmentation', 'VOCDetection', 'Cityscapes')
+           'VOCSegmentation', 'VOCDetection', 'Cityscapes',
+           'Caltech101', 'Caltech256', 'CelebA')
torchvision/datasets/caltech.py
from __future__ import print_function
import os
import os.path

from PIL import Image

import torch.utils.data as data

from .utils import download_url, makedir_exist_ok

class Caltech101(data.Dataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified target types.
``category`` represents the target class, and ``annotation`` is a list of points
from a hand-generated outline. Defaults to ``category``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
"""
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech101")
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class
        # The category names in "101_ObjectCategories" and "Annotations" do not
        # always match; this is a manual mapping between the two. Categories not
        # listed here use the same name in both directories.
name_map = {"Faces": "Faces_2",
"Faces_easy": "Faces_3",
"Motorbikes": "Motorbikes_16",
"airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
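        # self.index[i] is the 1-based image number within its category;
        # self.y[i] is the index of that category in self.categories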
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
            tuple: (image, target) where the type of target is specified by ``target_type``.
"""
import scipy.io
img = Image.open(os.path.join(self.root,
"101_ObjectCategories",
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index])))
target = []
for t in self.target_type:
if t == "category":
target.append(self.y[index])
elif t == "annotation":
data = scipy.io.loadmat(os.path.join(self.root,
"Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def _check_integrity(self):
        # could be made more robust by also checking the hashes of the files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
def __len__(self):
return len(self.index)
def download(self):
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")
# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
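# A minimal usage sketch (illustrative, not part of this commit; the root path
# and transform choice are assumptions):
#
#     dataset = Caltech101("~/data", target_type=["category", "annotation"],
#                          download=True)
#     img, (category, contour) = dataset[0]  # PIL image, class index, outline points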
class Caltech256(data.Dataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
"""
def __init__(self, root,
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech256")
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
            tuple: (image, target) where target is the index of the target class.
"""
img = Image.open(os.path.join(self.root,
"256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))
target = self.y[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def _check_integrity(self):
        # could be made more robust by also checking the hashes of the files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
def __len__(self):
return len(self.index)
def download(self):
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
torchvision/datasets/celeba.py
import os

import PIL
import torch
import torch.utils.data as data

from .utils import download_file_from_google_drive, check_integrity
class CelebA(data.Dataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
        split (string): One of {'train', 'valid', 'test'}; selects the
            corresponding split of the dataset.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
or ``landmarks``. Can also be a list to output a tuple with all specified target types.
The targets represent:
            ``attr`` (Tensor shape=(40,) dtype=int): binary (0, 1) labels for attributes
            ``identity`` (int): label for each person (data points with the same identity are the same person)
            ``bbox`` (Tensor shape=(4,) dtype=int): bounding box (x, y, width, height)
            ``landmarks`` (Tensor shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
                righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
Defaults to ``attr``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is
            not downloaded again.
"""
base_folder = "celeba"
    # There currently does not appear to be an easy way to extract 7z archives
    # in Python (without introducing additional dependencies). The "in-the-wild"
    # (not aligned+cropped) images are only distributed as 7z, so they are not
    # available right now.
file_list = [
# File ID MD5 Hash Filename
("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
# ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
# ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
# ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
def __init__(self, root,
split="train",
target_type="attr",
transform=None, target_transform=None,
download=False):
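        # pandas is imported lazily so it is a dependency only when CelebA is
        # actually used (see "Only import pandas when needed" in the commit message)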
import pandas
self.root = os.path.expanduser(root)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if split.lower() == "train":
split = 0
elif split.lower() == "valid":
split = 1
elif split.lower() == "test":
split = 2
else:
raise ValueError('Wrong split entered! Please use split="train" '
'or split="valid" or split="test"')
with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
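        # The bbox, landmark, and attribute files begin with a line giving the
        # number of images, then a header row; header=1 skips the count line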
with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
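        # column 1 of list_eval_partition.txt holds the partition code for each
        # image: 0 = train, 1 = valid, 2 = test (matching the mapping above)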
mask = (splits[1] == split)
self.filename = splits[mask].index.values
self.identity = torch.as_tensor(self.identity[mask].values)
self.bbox = torch.as_tensor(self.bbox[mask].values)
self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
self.attr = torch.as_tensor(self.attr[mask].values)
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
def _check_integrity(self):
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
_, ext = os.path.splitext(filename)
# Allow original archive to be deleted (zip and 7z)
# Only need the extracted images
if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
return False
# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
def download(self):
import zipfile
if self._check_integrity():
print('Files already downloaded and verified')
return
for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
f.extractall(os.path.join(self.root, self.base_folder))
def __getitem__(self, index):
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
target = []
for t in self.target_type:
if t == "attr":
target.append(self.attr[index, :])
elif t == "identity":
target.append(self.identity[index, 0])
elif t == "bbox":
target.append(self.bbox[index, :])
elif t == "landmarks":
target.append(self.landmarks_align[index, :])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None:
X = self.transform(X)
if self.target_transform is not None:
target = self.target_transform(target)
return X, target
def __len__(self):
return len(self.attr)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
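# A minimal usage sketch (illustrative, not part of this commit; the root path
# is an assumption):
#
#     celeba = CelebA("~/data", split="valid",
#                     target_type=["attr", "identity"], download=True)
#     img, (attr, identity) = celeba[0]  # attr: (40,) 0/1 tensor; identity: person label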
torchvision/datasets/utils.py
@@ -52,8 +52,8 @@ def download_url(url, root, filename=None, md5=None):
     Args:
         url (str): URL to download file from
         root (str): Directory to place downloaded file in
-        filename (str): Name to save the file under. If None, use the basename of the URL
-        md5 (str): MD5 checksum of the download. If None, do not check
+        filename (str, optional): Name to save the file under. If None, use the basename of the URL
+        md5 (str, optional): MD5 checksum of the download. If None, do not check
     """
     from six.moves import urllib
@@ -129,3 +129,58 @@ def list_files(root, suffix, prefix=False):
         files = [os.path.join(root, d) for d in files]
     return files
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
"""Download a Google Drive file from and place it in root.
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
makedir_exist_ok(root)
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()
response = session.get(url, params={'id': file_id}, stream=True)
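        # Large files trigger Google Drive's virus-scan interstitial instead of
        # the file itself; the "download_warning" cookie then carries a token
        # that must be echoed back to confirm the download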
token = _get_confirm_token(response)
if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)
_save_response_content(response, fpath)
def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def _save_response_content(response, destination, chunk_size=32768):
    # NOTE: relies on tqdm; the import is not shown in this excerpt, so it is
    # assumed to be available at the top of utils.py
    with open(destination, "wb") as f:
        pbar = tqdm(total=None)
        progress = 0
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress += len(chunk)
                # pbar.n is the bar's current count, so updating by the
                # difference keeps it in sync with total bytes written
                pbar.update(progress - pbar.n)
        pbar.close()
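# A minimal usage sketch (illustrative; the id, md5, and filename are taken
# from the CelebA file list above, the root path is an assumption):
#
#     download_file_from_google_drive("0B7EVK8r0v71pZjFTYXZWM3FlRnM",
#                                     "~/data/celeba",
#                                     "img_align_celeba.zip",
#                                     "00d2c5bc6d35e252742224ab0c1e8fcb")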