Unverified Commit 262d6177 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.omniglot (#2533)

parent ec9c7a54
from PIL import Image
from os.path import join
import os
from typing import Any, Callable, List, Optional, Tuple
from .vision import VisionDataset
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
......@@ -27,8 +28,14 @@ class Omniglot(VisionDataset):
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
}
def __init__(self, root, background=True, transform=None, target_transform=None,
download=False):
def __init__(
self,
root: str,
background: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
target_transform=target_transform)
self.background = background
......@@ -42,16 +49,16 @@ class Omniglot(VisionDataset):
self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder)
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], [])
self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], [])
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
for idx, character in enumerate(self._characters)]
self._flat_character_images = sum(self._character_images, [])
self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
def __len__(self):
def __len__(self) -> int:
return len(self._flat_character_images)
def __getitem__(self, index):
def __getitem__(self, index: int) -> Tuple[Any, Any]:
"""
Args:
index (int): Index
......@@ -71,13 +78,13 @@ class Omniglot(VisionDataset):
return image, character_class
def _check_integrity(self):
def _check_integrity(self) -> bool:
zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
return False
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print('Files already downloaded and verified')
return
......@@ -87,5 +94,5 @@ class Omniglot(VisionDataset):
url = self.download_url_prefix + '/' + zip_filename
download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
def _get_target_folder(self):
def _get_target_folder(self) -> str:
return 'images_background' if self.background else 'images_evaluation'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment