Unverified commit 15bd87f2, authored by Philip Meier and committed by GitHub
Browse files

flickr (#2529)

parent f1d7c92d
from collections import defaultdict
from PIL import Image
from html.parser import HTMLParser
from typing import Any, Callable, Dict, List, Optional, Tuple
import glob
import os
......@@ -10,32 +11,32 @@ from .vision import VisionDataset
class Flickr8kParser(HTMLParser):
"""Parser for extracting captions from the Flickr8k dataset web page."""
def __init__(self, root: str) -> None:
    """Initialise the parser with an empty caption store and reset state.

    Args:
        root: Stored unchanged on the instance as ``self.root``.
    """
    super(Flickr8kParser, self).__init__()

    self.root = root

    # Data structure to store captions
    self.annotations: Dict[str, List[str]] = {}

    # State variables
    self.in_table = False
    self.current_tag: Optional[str] = None
    self.current_img: Optional[str] = None
def handle_starttag(self, tag: str, attrs: List[Tuple[str, Optional[str]]]) -> None:
    """Record the tag that was just opened and note when a table begins.

    Args:
        tag: Name of the opening tag.
        attrs: Attribute pairs of the tag; not used by this handler.
    """
    self.current_tag = tag

    # Entering a <table> switches the parser into "inside table" mode.
    if tag == 'table':
        self.in_table = True
def handle_endtag(self, tag: str) -> None:
    """Clear the current tag and note when a table ends.

    Args:
        tag: Name of the closing tag.
    """
    self.current_tag = None

    # Leaving a </table> switches the parser out of "inside table" mode.
    if tag == 'table':
        self.in_table = False
def handle_data(self, data):
def handle_data(self, data: str) -> None:
if self.in_table:
if data == 'Image Not Found':
self.current_img = None
......@@ -62,7 +63,13 @@ class Flickr8k(VisionDataset):
target and transforms it.
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
def __init__(
self,
root: str,
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr8k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
......@@ -75,7 +82,7 @@ class Flickr8k(VisionDataset):
self.ids = list(sorted(self.annotations.keys()))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -97,7 +104,7 @@ class Flickr8k(VisionDataset):
return img, target
def __len__(self) -> int:
    """Return the number of entries in ``self.ids``."""
    return len(self.ids)
......@@ -113,7 +120,13 @@ class Flickr30k(VisionDataset):
target and transforms it.
"""
def __init__(self, root, ann_file, transform=None, target_transform=None):
def __init__(
self,
root: str,
ann_file: str,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
super(Flickr30k, self).__init__(root, transform=transform,
target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
......@@ -127,7 +140,7 @@ class Flickr30k(VisionDataset):
self.ids = list(sorted(self.annotations.keys()))
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -150,5 +163,5 @@ class Flickr30k(VisionDataset):
return img, target
def __len__(self) -> int:
    """Return the number of entries in ``self.ids``."""
    return len(self.ids)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment