Commit e8e04f06 authored by Adam J. Stewart, committed by Francisco Massa
Browse files

Add Flickr8k and Flickr30k Datasets (#674)

* Add Flickr8k and Flickr30k Datasets

* Add Flickr to the dataset docs

* Sort ids, glob during construction

* annFile -> ann_file

* Fix undefined variable name bug
parent 878a771f
...@@ -138,3 +138,14 @@ SBU ...@@ -138,3 +138,14 @@ SBU
:members: __getitem__ :members: __getitem__
:special-members: :special-members:
Flickr
~~~~~~
.. autoclass:: Flickr8k
:members: __getitem__
:special-members:
.. autoclass:: Flickr30k
:members: __getitem__
:special-members:
...@@ -10,10 +10,11 @@ from .fakedata import FakeData ...@@ -10,10 +10,11 @@ from .fakedata import FakeData
from .semeion import SEMEION from .semeion import SEMEION
from .omniglot import Omniglot from .omniglot import Omniglot
from .sbu import SBU from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
__all__ = ('LSUN', 'LSUNClass', __all__ = ('LSUN', 'LSUNClass',
'ImageFolder', 'DatasetFolder', 'FakeData', 'ImageFolder', 'DatasetFolder', 'FakeData',
'CocoCaptions', 'CocoDetection', 'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION', 'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
'Omniglot', 'SBU') 'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k')
from collections import defaultdict
from PIL import Image
from six.moves import html_parser
import glob
import os
import torch.utils.data as data
class Flickr8kParser(html_parser.HTMLParser):
    """HTML parser that collects image captions from the Flickr8k web page.

    While inside a ``<table>`` element, the text of every ``<a>`` tag is
    resolved to an image file under ``root`` and the text of every following
    ``<li>`` tag is recorded as a caption for that image.

    Args:
        root (string): Directory that is searched for the image files.

    Attributes populated while feeding:
        annotations (dict): Maps the matched image path to its list of captions.
    """

    def __init__(self, root):
        super(Flickr8kParser, self).__init__()

        self.root = root

        # Captions collected so far, keyed by image path
        self.annotations = {}

        # Parsing state
        self.in_table = False
        self.current_tag = None
        self.current_img = None

    def handle_starttag(self, tag, attrs):
        """Remember the tag that was just opened and note entry into a table."""
        if tag == 'table':
            self.in_table = True
        self.current_tag = tag

    def handle_endtag(self, tag):
        """Forget the current tag and note when the table is left."""
        if tag == 'table':
            self.in_table = False
        self.current_tag = None

    def handle_data(self, data):
        """Interpret text found inside the table as an image link or a caption."""
        # Text outside of the table carries no annotations
        if not self.in_table:
            return

        # Placeholder entry: captions that follow must not be attached anywhere
        if data == 'Image Not Found':
            self.current_img = None
            return

        tag = self.current_tag
        if tag == 'a':
            # The second-to-last '/'-separated component of the link text is
            # used as the file name prefix of the image under ``root``
            prefix = data.split('/')[-2]
            pattern = os.path.join(self.root, prefix + '_*.jpg')
            matches = glob.glob(pattern)
            path = matches[0]
            self.current_img = path
            self.annotations[path] = []
        elif tag == 'li' and self.current_img:
            self.annotations[self.current_img].append(data.strip())
class Flickr8k(data.Dataset):
    """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.

    The annotation file is parsed once at construction time with
    ``Flickr8kParser``; images are opened lazily on item access.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.ann_file = os.path.expanduser(ann_file)
        self.transform = transform
        self.target_transform = target_transform

        # Parse the whole annotation page and keep the resulting mapping
        # of image path -> list of captions
        caption_parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            page = fh.read()
            caption_parser.feed(page)
        self.annotations = caption_parser.annotations

        # Sorted keys give every image a stable integer index
        self.ids = sorted(self.annotations)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        # The stored ids are the image paths themselves
        path = self.ids[index]

        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        captions = self.annotations[path]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.ids)
class Flickr30k(data.Dataset):
    """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.

    The annotation file is read line by line at construction time. Each
    non-blank line must contain an image field and a caption separated by a
    tab; the caption is everything after the first tab. Blank lines are
    ignored.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Raises:
        ValueError: If a non-blank line of the annotation file contains no tab.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.ann_file = os.path.expanduser(ann_file)
        self.transform = transform
        self.target_transform = target_transform

        # Read annotations and group the captions by image file name
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate empty lines (e.g. a trailing newline at the
                    # end of the file) instead of failing on the unpack below
                    continue
                # Split on the first tab only so that a caption which itself
                # contains a tab is kept intact
                img_id, caption = line.split('\t', 1)
                # The last two characters of the image field are dropped
                # (presumably the '#<n>' caption index -- confirm against
                # the annotation file format)
                self.annotations[img_id[:-2]].append(caption)

        # Sorted keys give every image a stable integer index
        self.ids = sorted(self.annotations)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image: ids are file names relative to the root directory
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment