flickr.py 5.38 KB
Newer Older
1
2
from collections import defaultdict
from PIL import Image
Philip Meier's avatar
Philip Meier committed
3
from html.parser import HTMLParser
Philip Meier's avatar
Philip Meier committed
4
from typing import Any, Callable, Dict, List, Optional, Tuple
5
6
7

import glob
import os
8
from .vision import VisionDataset
9
10


Philip Meier's avatar
Philip Meier committed
11
class Flickr8kParser(HTMLParser):
    """Parser for extracting captions from the Flickr8k dataset web page.

    Inside a ``<table>``, the text of an ``<a>`` element is treated as a Flickr
    photo URL whose second-to-last ``/``-separated component is the photo id.
    The id is resolved to a local file by globbing ``<root>/<photo id>_*.jpg``.
    Text of the following ``<li>`` elements is collected as that image's
    captions. The literal text ``'Image Not Found'`` discards the current image.

    Args:
        root (string): Directory containing the downloaded images.

    Attributes:
        annotations (dict): Maps each resolved local image path to its list of
            (whitespace-stripped) captions.
    """

    def __init__(self, root: str) -> None:
        super(Flickr8kParser, self).__init__()

        self.root = root

        # Local image path -> list of captions.
        self.annotations: Dict[str, List[str]] = {}

        # Parser state: whether we are inside a <table>, the most recently
        # opened tag (reset to None by any end tag), and the image path that
        # subsequent <li> captions belong to (None = drop captions).
        self.in_table = False
        self.current_tag: Optional[str] = None
        self.current_img: Optional[str] = None

    def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
        """Record the opened tag and note when a ``<table>`` starts."""
        self.current_tag = tag

        if tag == 'table':
            self.in_table = True

    def handle_endtag(self, tag: str) -> None:
        """Clear the current tag and note when a ``<table>`` ends."""
        self.current_tag = None

        if tag == 'table':
            self.in_table = False

    def handle_data(self, data: str) -> None:
        """Dispatch table text to image resolution (``<a>``) or captions (``<li>``)."""
        if not self.in_table:
            return

        if data == 'Image Not Found':
            self.current_img = None
        elif self.current_tag == 'a':
            self.current_img = self._resolve_image(data)
            if self.current_img is not None:
                self.annotations[self.current_img] = []
        elif self.current_tag == 'li' and self.current_img:
            self.annotations[self.current_img].append(data.strip())

    def _resolve_image(self, url: str) -> Optional[str]:
        """Return the local image path for a Flickr photo URL, or None if absent."""
        parts = url.split('/')
        if len(parts) < 2:
            # Not a URL of the expected ".../<photo id>/" shape.
            return None
        photo_id = parts[-2]
        # Escape root so glob metacharacters in the directory name ("[", "*",
        # "?") are matched literally instead of as a pattern.
        pattern = os.path.join(glob.escape(self.root), photo_id + '_*.jpg')
        # glob order is arbitrary; sort so the chosen file is deterministic.
        matches = sorted(glob.glob(pattern))
        if not matches:
            # Image not downloaded: skip its captions rather than raising IndexError.
            return None
        return matches[0]


54
class Flickr8k(VisionDataset):
    """`Flickr8k Entities <http://hockenmaier.cs.illinois.edu/8k-pictures.html>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file (the dataset's HTML caption
            page, parsed with ``Flickr8kParser``).
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
            self,
            root: str,
            ann_file: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        super(Flickr8k, self).__init__(root, transform=transform,
                                       target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations; keys are local image paths resolved under self.root.
        parser = Flickr8kParser(self.root)
        with open(self.ann_file) as fh:
            parser.feed(fh.read())
        # close() forces HTMLParser to process any text still buffered at the
        # end of input; without it trailing data is never delivered.
        parser.close()
        self.annotations = parser.annotations

        # Sorted image paths give a stable index -> image mapping.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image (ids are already full file paths, see Flickr8kParser)
        img = Image.open(img_id).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of images that have annotations."""
        return len(self.ids)


111
class Flickr30k(VisionDataset):
    """`Flickr30k Entities <http://web.engr.illinois.edu/~bplumme2/Flickr30kEntities/>`_ Dataset.

    Args:
        root (string): Root directory where images are downloaded to.
        ann_file (string): Path to annotation file. Each non-blank line is
            ``<image file>#<n>\\t<caption>``; captions are grouped by image file.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
    """

    def __init__(
            self,
            root: str,
            ann_file: str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        super(Flickr30k, self).__init__(root, transform=transform,
                                        target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)

        # Read annotations and group captions by image file name.
        self.annotations = defaultdict(list)
        with open(self.ann_file) as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line) instead
                    # of failing to unpack them.
                    continue
                # maxsplit=1 keeps any further tabs inside the caption.
                img_id, caption = line.split('\t', 1)
                # Drop the "#<n>" caption-index suffix (any number of digits);
                # ids without a '#' are used unchanged.
                img_file, sep, _ = img_id.rpartition('#')
                self.annotations[img_file if sep else img_id].append(caption)

        # Sorted image file names give a stable index -> image mapping.
        self.ids = sorted(self.annotations)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is a list of captions for the image.
        """
        img_id = self.ids[index]

        # Image (ids are file names relative to root)
        filename = os.path.join(self.root, img_id)
        img = Image.open(filename).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # Captions
        target = self.annotations[img_id]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of distinct annotated images."""
        return len(self.ids)