Commit 9d9f48a3 authored by Bryan He, committed by Francisco Massa

Add Caltech101, Caltech256, and CelebA (#775)

* Add Caltech101 and Caltech256

* Add information about default for target_type

* Fix docs

* Add function to download from Google Drive

* Add CelebA dataset

* Only import pandas when needed

* Addressing comments

* Remove trailing whitespace

* Replace torch.LongTensor with torch.as_tensor
parent ab86a3a1
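torchvision/datasets/__init__.py: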
@@ -13,6 +13,8 @@ from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
from .voc import VOCSegmentation, VOCDetection
from .cityscapes import Cityscapes
from .caltech import Caltech101, Caltech256
from .celeba import CelebA
__all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData',
@@ -20,4 +22,5 @@ __all__ = ('LSUN', 'LSUNClass',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
-'VOCSegmentation', 'VOCDetection', 'Cityscapes')
+'VOCSegmentation', 'VOCDetection', 'Cityscapes',
+'Caltech101', 'Caltech256', 'CelebA')
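With these exports in place, the new datasets are importable directly from the package, e.g. ``from torchvision.datasets import Caltech101, Caltech256, CelebA``.

torchvision/datasets/caltech.py (new file):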
from __future__ import print_function
import os
import os.path
from PIL import Image
import torch.utils.data as data
from .utils import download_url, check_integrity, makedir_exist_ok
class Caltech101(data.Dataset):
"""`Caltech 101 <http://www.vision.caltech.edu/Image_Datasets/Caltech101/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``caltech101`` exists or will be saved to if download is set to True.
target_type (string or list, optional): Type of target to use, ``category`` or
``annotation``. Can also be a list to output a tuple with all specified target types.
``category`` represents the target class, and ``annotation`` is a list of points
from a hand-generated outline. Defaults to ``category``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is
not downloaded again.
"""
def __init__(self, root, target_type="category",
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech101")
makedir_exist_ok(self.root)
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
self.categories.remove("BACKGROUND_Google") # this is not a real class
# For some reason, the category names in "101_ObjectCategories" and
# "Annotations" do not always match. This is a manual map between the
# two. Defaults to using the same name, since most names are fine.
name_map = {"Faces": "Faces_2",
"Faces_easy": "Faces_3",
"Motorbikes": "Motorbikes_16",
"airplanes": "Airplanes_Side_2"}
self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where the type of target is specified by ``target_type``.
"""
import scipy.io  # imported here to avoid a hard scipy dependency at module import time
img = Image.open(os.path.join(self.root,
"101_ObjectCategories",
self.categories[self.y[index]],
"image_{:04d}.jpg".format(self.index[index])))
target = []
for t in self.target_type:
if t == "category":
target.append(self.y[index])
elif t == "annotation":
data = scipy.io.loadmat(os.path.join(self.root,
"Annotations",
self.annotation_categories[self.y[index]],
"annotation_{:04d}.mat".format(self.index[index])))
target.append(data["obj_contour"])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def _check_integrity(self):
# could be made more robust by checking the hashes of the files
return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
def __len__(self):
return len(self.index)
def download(self):
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz",
self.root,
"101_ObjectCategories.tar.gz",
"b224c7392d521a49829488ab0f1120d9")
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech101/Annotations.tar",
self.root,
"101_Annotations.tar",
"6f83eeb1f24d99cab4eb377263132c91")
# extract file
with tarfile.open(os.path.join(self.root, "101_ObjectCategories.tar.gz"), "r:gz") as tar:
tar.extractall(path=self.root)
with tarfile.open(os.path.join(self.root, "101_Annotations.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
class Caltech256(data.Dataset):
"""`Caltech 256 <http://www.vision.caltech.edu/Image_Datasets/Caltech256/>`_ Dataset.
Args:
root (string): Root directory of dataset where directory
``caltech256`` exists or will be saved to if download is set to True.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is
not downloaded again.
"""
def __init__(self, root,
transform=None, target_transform=None,
download=False):
self.root = os.path.join(os.path.expanduser(root), "caltech256")
makedir_exist_ok(self.root)
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
self.index = []
self.y = []
for (i, c) in enumerate(self.categories):
n = len(os.listdir(os.path.join(self.root, "256_ObjectCategories", c)))
self.index.extend(range(1, n + 1))
self.y.extend(n * [i])
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the index of the target class.
"""
img = Image.open(os.path.join(self.root,
"256_ObjectCategories",
self.categories[self.y[index]],
"{:03d}_{:04d}.jpg".format(self.y[index] + 1, self.index[index])))
target = self.y[index]
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def _check_integrity(self):
# could be made more robust by checking the hashes of the files
return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
def __len__(self):
return len(self.index)
def download(self):
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
return
download_url("http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar",
self.root,
"256_ObjectCategories.tar",
"67b4f42ca05d46448c6bb8ecd2220f6d")
# extract file
with tarfile.open(os.path.join(self.root, "256_ObjectCategories.tar"), "r:") as tar:
tar.extractall(path=self.root)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
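A hedged usage sketch of the two classes above; the ``./data`` root is illustrative, and ``scipy`` must be installed for Caltech101 samples to load (it is imported inside ``__getitem__``):

from torchvision.datasets import Caltech101, Caltech256
# Passing a list as target_type makes __getitem__ return a tuple of targets.
c101 = Caltech101("./data", target_type=["category", "annotation"], download=True)
img, (label, contour) = c101[0]  # contour: outline points from the .mat annotation
c256 = Caltech256("./data", download=True)
img, label = c256[0]  # label: integer index into the sorted category list

torchvision/datasets/celeba.py (new file):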
import torch
import torch.utils.data as data
import os
import PIL
from .utils import download_file_from_google_drive, check_integrity
class CelebA(data.Dataset):
"""`Large-scale CelebFaces Attributes (CelebA) Dataset <http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html>`_ Dataset.
Args:
root (string): Root directory where images are downloaded to.
split (string): One of {'train', 'valid', 'test'};
the corresponding subset of the dataset is selected.
target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``,
or ``landmarks``. Can also be a list to output a tuple with all specified target types.
The targets represent:
``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes
``identity`` (int): label for each person (data points with the same identity are the same person)
``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height)
``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x,
righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y)
Defaults to ``attr``.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g., ``transforms.ToTensor``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is
not downloaded again.
"""
base_folder = "celeba"
# There currently does not appear to be an easy way to extract 7z archives in
# Python (without introducing additional dependencies). The "in-the-wild"
# (not aligned+cropped) images are only distributed as 7z, so they are not
# available right now.
file_list = [
# File ID MD5 Hash Filename
("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"),
# ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"),
# ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"),
("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"),
("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"),
("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"),
("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"),
# ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"),
("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"),
]
def __init__(self, root,
split="train",
target_type="attr",
transform=None, target_transform=None,
download=False):
import pandas  # imported locally so pandas is only required when CelebA is used
self.root = os.path.expanduser(root)
self.split = split
if isinstance(target_type, list):
self.target_type = target_type
else:
self.target_type = [target_type]
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if split.lower() == "train":
split = 0
elif split.lower() == "valid":
split = 1
elif split.lower() == "test":
split = 2
else:
raise ValueError('Wrong split entered! Please use split="train" '
'or split="valid" or split="test"')
with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f:
splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f:
self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0)
with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f:
self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0)
with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f:
self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1)
with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f:
self.attr = pandas.read_csv(f, delim_whitespace=True, header=1)
mask = (splits[1] == split)
self.filename = splits[mask].index.values
self.identity = torch.as_tensor(self.identity[mask].values)
self.bbox = torch.as_tensor(self.bbox[mask].values)
self.landmarks_align = torch.as_tensor(self.landmarks_align[mask].values)
self.attr = torch.as_tensor(self.attr[mask].values)
self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1}
def _check_integrity(self):
for (_, md5, filename) in self.file_list:
fpath = os.path.join(self.root, self.base_folder, filename)
_, ext = os.path.splitext(filename)
# Allow original archive to be deleted (zip and 7z)
# Only need the extracted images
if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
return False
# Should check a hash of the images
return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))
def download(self):
import zipfile
if self._check_integrity():
print('Files already downloaded and verified')
return
for (file_id, md5, filename) in self.file_list:
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)
with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
f.extractall(os.path.join(self.root, self.base_folder))
def __getitem__(self, index):
X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))
target = []
for t in self.target_type:
if t == "attr":
target.append(self.attr[index, :])
elif t == "identity":
target.append(self.identity[index, 0])
elif t == "bbox":
target.append(self.bbox[index, :])
elif t == "landmarks":
target.append(self.landmarks_align[index, :])
else:
raise ValueError("Target type \"{}\" is not recognized.".format(t))
target = tuple(target) if len(target) > 1 else target[0]
if self.transform is not None:
X = self.transform(X)
if self.target_transform is not None:
target = self.target_transform(target)
return X, target
def __len__(self):
return len(self.attr)
def __repr__(self):
fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
fmt_str += ' Target type: {}\n'.format(self.target_type)
fmt_str += ' Split: {}\n'.format(self.split)
fmt_str += ' Root Location: {}\n'.format(self.root)
tmp = ' Transforms (if any): '
fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
tmp = ' Target Transforms (if any): '
fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
return fmt_str
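A minimal usage sketch of the new class (the root path is illustrative; ``pandas`` is required at construction time):

from torchvision.datasets import CelebA
celeba = CelebA("./data", split="valid", target_type=["attr", "identity"], download=True)
img, (attr, identity) = celeba[0]
# attr: shape-(40,) tensor of {0, 1} attribute labels
# identity: integer id shared by all images of the same person

torchvision/datasets/utils.py: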
@@ -52,8 +52,8 @@ def download_url(url, root, filename=None, md5=None):
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
-filename (str): Name to save the file under. If None, use the basename of the URL
-md5 (str): MD5 checksum of the download. If None, do not check
+filename (str, optional): Name to save the file under. If None, use the basename of the URL
+md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
from six.moves import urllib
@@ -129,3 +129,58 @@ def list_files(root, suffix, prefix=False):
files = [os.path.join(root, d) for d in files]
return files
def download_file_from_google_drive(file_id, root, filename=None, md5=None):
"""Download a Google Drive file from and place it in root.
Args:
file_id (str): id of file to be downloaded
root (str): Directory to place downloaded file in
filename (str, optional): Name to save the file under. If None, use the id of the file.
md5 (str, optional): MD5 checksum of the download. If None, do not check
"""
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
import requests
url = "https://docs.google.com/uc?export=download"
root = os.path.expanduser(root)
if not filename:
filename = file_id
fpath = os.path.join(root, filename)
makedir_exist_ok(root)
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
session = requests.Session()
response = session.get(url, params={'id': file_id}, stream=True)
token = _get_confirm_token(response)
if token:
params = {'id': file_id, 'confirm': token}
response = session.get(url, params=params, stream=True)
_save_response_content(response, fpath)
def _get_confirm_token(response):
for key, value in response.cookies.items():
if key.startswith('download_warning'):
return value
return None
def _save_response_content(response, destination, chunk_size=32768):
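# NB: assumes tqdm is imported at the top of utils.py (e.g. `from tqdm import tqdm`);
# that import is outside this hunk, so it does not appear in the diff.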
with open(destination, "wb") as f:
pbar = tqdm(total=None)
progress = 0
for chunk in response.iter_content(chunk_size):
if chunk: # filter out keep-alive new chunks
f.write(chunk)
progress += len(chunk)
pbar.update(progress - pbar.n)
pbar.close()
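For completeness, a hedged example of calling the new helper, reusing a file id/MD5 pair from the CelebA ``file_list`` above (the destination directory is illustrative):

import os
from torchvision.datasets.utils import download_file_from_google_drive
download_file_from_google_drive(
    "0B7EVK8r0v71pblRyaVFSWGxPY0U",        # Drive id of list_attr_celeba.txt
    os.path.expanduser("~/data/celeba"),   # directory to place the file in
    filename="list_attr_celeba.txt",
    md5="75e246fa4810816ffd6ee81facbd244c",
)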