Commit 878a771f authored by Adam J. Stewart's avatar Adam J. Stewart Committed by Francisco Massa
Browse files

Add SBU Captioned Photo Dataset (#665)

* Add SBU Captioned Photo Dataset

* Add SBU to the dataset docs
parent d5637696
...@@ -129,3 +129,12 @@ PhotoTour ...@@ -129,3 +129,12 @@ PhotoTour
.. autoclass:: PhotoTour .. autoclass:: PhotoTour
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
SBU
~~~
.. autoclass:: SBU
:members: __getitem__
:special-members:
...@@ -9,10 +9,11 @@ from .phototour import PhotoTour ...@@ -9,10 +9,11 @@ from .phototour import PhotoTour
from .fakedata import FakeData from .fakedata import FakeData
from .semeion import SEMEION from .semeion import SEMEION
from .omniglot import Omniglot from .omniglot import Omniglot
from .sbu import SBU
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData', 'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot') 'Omniglot', 'SBU')
from PIL import Image
from six.moves import zip
from .utils import download_url, check_integrity
import os
import torch.utils.data as data
class SBU(data.Dataset):
"""`SBU Captioned Photo <http://www.cs.virginia.edu/~vicente/sbucaptions/>`_ Dataset.
Args:
root (string): Root directory of dataset where tarball
``SBUCaptionedPhotoDataset.tar.gz`` exists.
transform (callable, optional): A function/transform that takes in a PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If True, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
url = "http://www.cs.virginia.edu/~vicente/sbucaptions/SBUCaptionedPhotoDataset.tar.gz"
filename = "SBUCaptionedPhotoDataset.tar.gz"
md5_checksum = '9aec147b3488753cf758b4d493422285'
def __init__(self, root, transform=None, target_transform=None, download=True):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
if download:
self.download()
if not self._check_integrity():
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
# Read the caption for each photo
self.photos = []
self.captions = []
file1 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')
file2 = os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_captions.txt')
for line1, line2 in zip(open(file1), open(file2)):
url = line1.rstrip()
photo = os.path.basename(url)
filename = os.path.join(self.root, 'dataset', photo)
if os.path.exists(filename):
caption = line2.rstrip()
self.photos.append(photo)
self.captions.append(caption)
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a caption for the photo.
"""
filename = os.path.join(self.root, 'dataset', self.photos[index])
img = Image.open(filename).convert('RGB')
if self.transform is not None:
img = self.transform(img)
target = self.captions[index]
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
"""The number of photos in the dataset."""
return len(self.photos)
def _check_integrity(self):
"""Check the md5 checksum of the downloaded tarball."""
root = self.root
fpath = os.path.join(root, self.filename)
if not check_integrity(fpath, self.md5_checksum):
return False
return True
def download(self):
"""Download and extract the tarball, and download each individual photo."""
import tarfile
if self._check_integrity():
print('Files already downloaded and verified')
return
download_url(self.url, self.root, self.filename, self.md5_checksum)
# Extract file
with tarfile.open(os.path.join(self.root, self.filename), 'r:gz') as tar:
tar.extractall(path=self.root)
# Download individual photos
with open(os.path.join(self.root, 'dataset', 'SBU_captioned_photo_dataset_urls.txt')) as fh:
for line in fh:
url = line.rstrip()
try:
download_url(url, os.path.join(self.root, 'dataset'))
except OSError:
# The images point to public images on Flickr.
# Note: Images might be removed by users at anytime.
pass
...@@ -44,10 +44,20 @@ def makedir_exist_ok(dirpath): ...@@ -44,10 +44,20 @@ def makedir_exist_ok(dirpath):
raise raise
def download_url(url, root, filename, md5): def download_url(url, root, filename=None, md5=None):
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str): Name to save the file under. If None, use the basename of the URL
md5 (str): MD5 checksum of the download. If None, do not check
"""
from six.moves import urllib from six.moves import urllib
root = os.path.expanduser(root) root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename) fpath = os.path.join(root, filename)
makedir_exist_ok(root) makedir_exist_ok(root)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment