Unverified Commit 262d6177 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add typehints for torchvision.datasets.omniglot (#2533)

parent ec9c7a54
from PIL import Image from PIL import Image
from os.path import join from os.path import join
import os import os
from typing import Any, Callable, List, Optional, Tuple
from .vision import VisionDataset from .vision import VisionDataset
from .utils import download_and_extract_archive, check_integrity, list_dir, list_files from .utils import download_and_extract_archive, check_integrity, list_dir, list_files
...@@ -27,8 +28,14 @@ class Omniglot(VisionDataset): ...@@ -27,8 +28,14 @@ class Omniglot(VisionDataset):
'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811' 'images_evaluation': '6b91aef0f799c5bb55b94e3f2daec811'
} }
def __init__(self, root, background=True, transform=None, target_transform=None, def __init__(
download=False): self,
root: str,
background: bool = True,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
) -> None:
super(Omniglot, self).__init__(join(root, self.folder), transform=transform, super(Omniglot, self).__init__(join(root, self.folder), transform=transform,
target_transform=target_transform) target_transform=target_transform)
self.background = background self.background = background
...@@ -42,16 +49,16 @@ class Omniglot(VisionDataset): ...@@ -42,16 +49,16 @@ class Omniglot(VisionDataset):
self.target_folder = join(self.root, self._get_target_folder()) self.target_folder = join(self.root, self._get_target_folder())
self._alphabets = list_dir(self.target_folder) self._alphabets = list_dir(self.target_folder)
self._characters = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))] self._characters: List[str] = sum([[join(a, c) for c in list_dir(join(self.target_folder, a))]
for a in self._alphabets], []) for a in self._alphabets], [])
self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')] self._character_images = [[(image, idx) for image in list_files(join(self.target_folder, character), '.png')]
for idx, character in enumerate(self._characters)] for idx, character in enumerate(self._characters)]
self._flat_character_images = sum(self._character_images, []) self._flat_character_images: List[Tuple[str, int]] = sum(self._character_images, [])
def __len__(self): def __len__(self) -> int:
return len(self._flat_character_images) return len(self._flat_character_images)
def __getitem__(self, index): def __getitem__(self, index: int) -> Tuple[Any, Any]:
""" """
Args: Args:
index (int): Index index (int): Index
...@@ -71,13 +78,13 @@ class Omniglot(VisionDataset): ...@@ -71,13 +78,13 @@ class Omniglot(VisionDataset):
return image, character_class return image, character_class
def _check_integrity(self): def _check_integrity(self) -> bool:
zip_filename = self._get_target_folder() zip_filename = self._get_target_folder()
if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]): if not check_integrity(join(self.root, zip_filename + '.zip'), self.zips_md5[zip_filename]):
return False return False
return True return True
def download(self): def download(self) -> None:
if self._check_integrity(): if self._check_integrity():
print('Files already downloaded and verified') print('Files already downloaded and verified')
return return
...@@ -87,5 +94,5 @@ class Omniglot(VisionDataset): ...@@ -87,5 +94,5 @@ class Omniglot(VisionDataset):
url = self.download_url_prefix + '/' + zip_filename url = self.download_url_prefix + '/' + zip_filename
download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename]) download_and_extract_archive(url, self.root, filename=zip_filename, md5=self.zips_md5[filename])
def _get_target_folder(self): def _get_target_folder(self) -> str:
return 'images_background' if self.background else 'images_evaluation' return 'images_background' if self.background else 'images_evaluation'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment