Commit e8e04f06 authored by Adam J. Stewart, committed by Francisco Massa
Browse files

Add Flickr8k and Flickr30k Datasets (#674)

* Add Flickr8k and Flickr30k Datasets

* Add Flickr to the dataset docs

* Sort ids, glob during construction

* annFile -> ann_file

* Fix undefined variable name bug
parent 878a771f
......@@ -138,3 +138,14 @@ SBU
:members: __getitem__
:special-members:
Flickr
~~~~~~

.. autoclass:: Flickr8k
  :members: __getitem__
  :special-members:

.. autoclass:: Flickr30k
  :members: __getitem__
  :special-members:
......@@ -10,10 +10,11 @@ from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
# Names exported by ``from <this package> import *``; Flickr8k and Flickr30k
# are the entries added alongside the ``.flickr`` import above.
__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST',
           'MNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k')
from collections import defaultdict
from PIL import Image
from six.moves import html_parser
import glob
import os
import torch.utils.data as data
class Flickr8kParser(html_parser.HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Only text inside a ``<table>`` element is considered. The text of an
    ``<a>`` element is treated as an image link whose second-to-last
    ``/``-separated component is the image id; the id is resolved to a local
    file matching ``<root>/<id>_*.jpg``. The text of each following ``<li>``
    element is stored as a caption for that file.

    Images that cannot be used are skipped together with their captions:
    rows whose text is ``Image Not Found``, links that do not contain an id,
    and ids with no matching file under ``root``.

    Args:
        root (string): Directory that is searched for the image files.

    Attributes:
        annotations (dict): Maps the path of each resolved image file to its
            list of captions.
        missing (list): Image ids found on the page for which no file
            matching ``<root>/<id>_*.jpg`` exists.
    """

    def __init__(self, root):
        super(Flickr8kParser, self).__init__()

        self.root = root

        # Data structures filled while parsing
        self.annotations = {}
        self.missing = []

        # State variables
        self.in_table = False
        self.current_tag = None
        self.current_img = None

    def handle_starttag(self, tag, attrs):
        """Remember the tag that was just opened and note entry into a table."""
        self.current_tag = tag

        if tag == 'table':
            self.in_table = True

    def handle_endtag(self, tag):
        """Forget the current tag and note when a table is left."""
        self.current_tag = None

        if tag == 'table':
            self.in_table = False

    def handle_data(self, data):
        """Record image links and captions found inside a table."""
        if not self.in_table:
            return

        if data == 'Image Not Found':
            self.current_img = None
        elif self.current_tag == 'a':
            self.current_img = self._find_image(data)
            if self.current_img is not None:
                self.annotations[self.current_img] = []
        elif self.current_tag == 'li' and self.current_img:
            self.annotations[self.current_img].append(data.strip())

    def _find_image(self, link):
        """Return the local file for the image named by ``link``, or None."""
        parts = link.split('/')
        if len(parts) < 2:
            # Link text without a '/' carries no image id.
            return None

        img_id = parts[-2]
        pattern = os.path.join(self.root, img_id + '_*.jpg')
        # glob returns matches in arbitrary order; sort so that the chosen
        # file does not depend on directory listing order.
        matches = sorted(glob.glob(pattern))
        if not matches:
            # Without a local file the image cannot be loaded later, so it is
            # left out of the annotations instead of failing on ``[0]``.
            self.missing.append(img_id)
            return None

        return matches[0]
class Flickr8k(data.Dataset):
    """`Flickr8k Entities <http://nlp.cs.illinois.edu/HockenmaierGroup/8k-pictures.html>`_ Dataset.

    The annotation file is parsed once at construction time with
    ``Flickr8kParser``; samples are then served from the resulting mapping of
    image path to captions, ordered by sorted image path.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.ann_file = os.path.expanduser(ann_file)
        self.transform = transform
        self.target_transform = target_transform

        # Parse the annotation page into {image path: [captions]}
        page_parser = Flickr8kParser(self.root)
        with open(self.ann_file) as page:
            page_parser.feed(page.read())
        self.annotations = page_parser.annotations

        # Fixed, sorted ordering of the image paths used for indexing
        ordered_paths = list(self.annotations.keys())
        ordered_paths.sort()
        self.ids = ordered_paths

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        path = self.ids[index]

        # Load the image as RGB, then apply the optional transform
        image = Image.open(path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        # Look up the captions, then apply the optional target transform
        captions = self.annotations[path]
        if self.target_transform is not None:
            captions = self.target_transform(captions)

        return image, captions

    def __len__(self):
        """Return the number of images that have annotations."""
        return len(self.ids)
class Flickr30k(data.Dataset):
    """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.

    Each non-blank line of the annotation file is read as
    ``<image name>#<caption index>`` followed by a tab and the caption text.
    Captions are grouped per image name; blank lines are ignored. A non-blank
    line without a tab raises ``ValueError``.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        self.root = os.path.expanduser(root)
        self.ann_file = os.path.expanduser(ann_file)
        self.transform = transform
        self.target_transform = target_transform

        # Read annotations and store in a dict
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) carry no caption.
                    continue
                # Split on the first tab only so that a caption containing
                # a tab is kept intact.
                img_id, caption = line.split('\t', 1)
                # Drop the '#<caption index>' suffix whatever its width.
                img_id = img_id.rsplit('#', 1)[0]
                self.annotations[img_id].append(caption)

        self.ids = sorted(self.annotations)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images that have annotations."""
        return len(self.ids)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment