Unverified commit 15bd87f2, authored by Philip Meier and committed by GitHub
Browse files

flickr (#2529)

parent f1d7c92d
from collections import defaultdict from collections import defaultdict
from PIL import Image from PIL import Image
from html.parser import HTMLParser from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple
import glob import glob
import os import os
...@@ -10,32 +11,32 @@ from .vision import VisionDataset ...@@ -10,32 +11,32 @@ from .vision import VisionDataset
class Flickr8kParser(HTMLParser): class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page.""" """Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root): def __init__(self, root: str) -> None:
super(Flickr8kParser, self).__init__() super(Flickr8kParser, self).__init__()
self.root = root self.root = root
# Data structure to store captions # Data structure to store captions
self.annotations = {} self.annotations: Dict[str, List[str]] = {}
# State variables # State variables
self.in_table = False self.in_table = False
self.current_tag = None self.current_tag: Optional[str] = None
self.current_img = None self.current_img: Optional[str] = None
def handle_starttag(self, tag, attrs): def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
self.current_tag = tag self.current_tag = tag
if tag == 'table': if tag == 'table':
self.in_table = True self.in_table = True
def handle_endtag(self, tag): def handle_endtag(self, tag: str) -> None:
self.current_tag = None self.current_tag = None
if tag == 'table': if tag == 'table':
self.in_table = False self.in_table = False
def handle_data(self, data): def handle_data(self, data: str) -> None:
if self.in_table: if self.in_table:
if data == 'Image Not Found': if data == 'Image Not Found':
self.current_img = None self.current_img = None
...@@ -62,7 +63,13 @@ class Flickr8k(VisionDataset): ...@@ -62,7 +63,13 @@ class Flickr8k(VisionDataset):
target and transforms it. target and transforms it.
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(
self,
root: str,
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr8k, self).__init__(root, transform=transform, super(Flickr8k, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
...@@ -75,7 +82,7 @@ class Flickr8k(VisionDataset): ...@@ -75,7 +82,7 @@ class Flickr8k(VisionDataset):
self.ids = list(sorted(self.annotations.keys())) self.ids = list(sorted(self.annotations.keys()))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -97,7 +104,7 @@ class Flickr8k(VisionDataset): ...@@ -97,7 +104,7 @@ class Flickr8k(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.ids) return len(self.ids)
...@@ -113,7 +120,13 @@ class Flickr30k(VisionDataset): ...@@ -113,7 +120,13 @@ class Flickr30k(VisionDataset):
target and transforms it. target and transforms it.
""" """
def __init__(self, root, ann_file, transform=None, target_transform=None): def __init__(
self,
root: str,
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr30k, self).__init__(root, transform=transform, super(Flickr30k, self).__init__(root, transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file) self.ann_file = os.path.expanduser(ann_file)
...@@ -127,7 +140,7 @@ class Flickr30k(VisionDataset): ...@@ -127,7 +140,7 @@ class Flickr30k(VisionDataset):
self.ids = list(sorted(self.annotations.keys())) self.ids = list(sorted(self.annotations.keys()))
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -150,5 +163,5 @@ class Flickr30k(VisionDataset): ...@@ -150,5 +163,5 @@ class Flickr30k(VisionDataset):
return img, target return img, target
def __len__(self): def __len__(self) -> int:
return len(self.ids) return len(self.ids)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment