Unverified commit 8b029651, authored by Philip Meier, committed by GitHub

[WIP] add typehints to datasets (#2487)

* enable mypy for torchvision.datasets and ignore existing errors

* imagenet

* utils

* lint
parent f5bf9d52
mypy.ini
@@ -4,9 +4,9 @@ files = torchvision
 show_error_codes = True
 pretty = True
 
-[mypy-torchvision.datasets.*]
+;[mypy-torchvision.datasets.*]
 
-ignore_errors = True
+;ignore_errors = True
 
 [mypy-torchvision.io.*]
@@ -24,7 +24,30 @@ ignore_errors = True
 ignore_errors = True
 
-[mypy-PIL]
+[mypy-PIL.*]
 
+ignore_missing_imports = True
+
+[mypy-numpy.*]
+
+ignore_missing_imports = True
+
+[mypy-scipy.*]
+
+ignore_missing_imports = True
+
+[mypy-pycocotools.*]
+
 ignore_missing_imports = True
+
+[mypy-lmdb.*]
+
+ignore_missing_imports = True
+
+[mypy-pandas.*]
+
+ignore_missing_imports = True
+
+[mypy-accimage.*]
+
+ignore_missing_imports = True
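
The config change does two things: commenting out the blanket ignore_errors for torchvision.datasets (a leading ";" is an INI comment) is what actually turns type checking on for the package, and the new per-module ignore_missing_imports = True overrides silence errors about third-party dependencies that ship without type stubs. A minimal sketch of the kind of error those overrides suppress when running, say, mypy --config-file mypy.ini (the helper below is illustrative, not from the diff):

    from PIL import Image  # without the [mypy-PIL.*] override, mypy flags this
                           # untyped import with a missing-stub error

    def load_image(path: str) -> Image.Image:
        # the function body itself type-checks normally once the import is accepted
        return Image.open(path)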
torchvision/datasets/imagenet.py
@@ -3,6 +3,7 @@ from contextlib import contextmanager
 import os
 import shutil
 import tempfile
+from typing import Any, Dict, List, Iterator, Optional, Tuple
 
 import torch
 
 from .folder import ImageFolder
 from .utils import check_integrity, extract_archive, verify_str_arg
@@ -37,7 +38,7 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=None, **kwargs):
+    def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
         if download is True:
             msg = ("The dataset is no longer publicly accessible. You need to "
                    "download the archives externally and place them in the root "
@@ -64,7 +65,7 @@ class ImageNet(ImageFolder):
                              for idx, clss in enumerate(self.classes)
                              for cls in clss}
 
-    def parse_archives(self):
+    def parse_archives(self) -> None:
         if not check_integrity(os.path.join(self.root, META_FILE)):
             parse_devkit_archive(self.root)
@@ -75,14 +76,14 @@ class ImageNet(ImageFolder):
             parse_val_archive(self.root)
 
     @property
-    def split_folder(self):
+    def split_folder(self) -> str:
         return os.path.join(self.root, self.split)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return "Split: {split}".format(**self.__dict__)
 
 
-def load_meta_file(root, file=None):
+def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
     if file is None:
         file = META_FILE
     file = os.path.join(root, file)
@@ -95,14 +96,14 @@ def load_meta_file(root, file=None):
         raise RuntimeError(msg.format(file, root))
 
 
-def _verify_archive(root, file, md5):
+def _verify_archive(root: str, file: str, md5: str) -> None:
     if not check_integrity(os.path.join(root, file), md5):
         msg = ("The archive {} is not present in the root directory or is corrupted. "
                "You need to download it externally and place it in {}.")
         raise RuntimeError(msg.format(file, root))
 
 
-def parse_devkit_archive(root, file=None):
+def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
     """Parse the devkit archive of the ImageNet2012 classification dataset and save
     the meta information in a binary file.
@@ -113,7 +114,7 @@ def parse_devkit_archive(root, file=None):
     """
     import scipy.io as sio
 
-    def parse_meta_mat(devkit_root):
+    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
         metafile = os.path.join(devkit_root, "data", "meta.mat")
         meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
         nums_children = list(zip(*meta))[4]
@@ -125,7 +126,7 @@ def parse_devkit_archive(root, file=None):
         wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
         return idx_to_wnid, wnid_to_classes
 
-    def parse_val_groundtruth_txt(devkit_root):
+    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
         file = os.path.join(devkit_root, "data",
                             "ILSVRC2012_validation_ground_truth.txt")
         with open(file, 'r') as txtfh:
@@ -133,7 +134,7 @@ def parse_devkit_archive(root, file=None):
         return [int(val_idx) for val_idx in val_idcs]
 
     @contextmanager
-    def get_tmp_dir():
+    def get_tmp_dir() -> Iterator[str]:
         tmp_dir = tempfile.mkdtemp()
         try:
             yield tmp_dir
@@ -158,7 +159,7 @@ def parse_devkit_archive(root, file=None):
     torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
 
 
-def parse_train_archive(root, file=None, folder="train"):
+def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
     """Parse the train images archive of the ImageNet2012 classification dataset and
     prepare it for usage with the ImageNet dataset.
@@ -184,7 +185,9 @@ def parse_train_archive(root, file=None, folder="train"):
         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
 
 
-def parse_val_archive(root, file=None, wnids=None, folder="val"):
+def parse_val_archive(
+    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
+) -> None:
     """Parse the validation images archive of the ImageNet2012 classification dataset
     and prepare it for usage with the ImageNet dataset.
......
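
Two annotation choices above are worth a note: get_tmp_dir is typed as returning Iterator[str], the conventional annotation for a generator decorated with @contextmanager, and load_meta_file now spells out the structure of the saved metadata. A usage sketch under the assumption that the archives have already been parsed into root (the path is hypothetical):

    from torchvision.datasets.imagenet import load_meta_file

    # Per the new annotation the result is Tuple[Dict[str, str], List[str]]:
    # a wnid-to-class mapping plus the wnids of the validation images in order.
    wnid_to_classes, val_wnids = load_meta_file("/data/imagenet")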
torchvision/datasets/mnist.py
@@ -343,7 +343,7 @@ class QMNIST(MNIST):
         'test50k': 'test',
         'nist': 'nist'
     }
-    resources = {
+    resources = {  # type: ignore[assignment]
         'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
                    'ed72d4157d28c017586c42bc6afe6370'),
                   ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
......
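
The # type: ignore[assignment] is needed because QMNIST overrides the resources class attribute it inherits from MNIST with a differently shaped value (a dict keyed by subset rather than a flat list), and mypy rejects class attribute overrides that change the type. A self-contained sketch of the same error, with illustrative names:

    from typing import Dict, List, Tuple

    class Base:
        resources: List[Tuple[str, str]] = [("url", "md5")]

    class Sub(Base):
        # mypy: Incompatible types in assignment (base class "Base"
        # defined the type as "List[Tuple[str, str]]")  [assignment]
        resources: Dict[str, List[Tuple[str, str]]] = {"train": [("url", "md5")]}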
torchvision/datasets/utils.py
@@ -4,13 +4,14 @@ import hashlib
 import gzip
 import errno
 import tarfile
+from typing import Any, Callable, List, Iterable, Optional, TypeVar
 import zipfile
 
 import torch
 from torch.utils.model_zoo import tqdm
 
 
-def gen_bar_updater():
+def gen_bar_updater() -> Callable[[int, int, int], None]:
     pbar = tqdm(total=None)
 
     def bar_update(count, block_size, total_size):
@@ -22,7 +23,7 @@ def gen_bar_updater():
     return bar_update
 
 
-def calculate_md5(fpath, chunk_size=1024 * 1024):
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     md5 = hashlib.md5()
     with open(fpath, 'rb') as f:
         for chunk in iter(lambda: f.read(chunk_size), b''):
@@ -30,11 +31,11 @@ def calculate_md5(fpath, chunk_size=1024 * 1024):
     return md5.hexdigest()
 
 
-def check_md5(fpath, md5, **kwargs):
+def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
     return md5 == calculate_md5(fpath, **kwargs)
 
 
-def check_integrity(fpath, md5=None):
+def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
     if not os.path.isfile(fpath):
         return False
     if md5 is None:
@@ -42,7 +43,7 @@ def check_integrity(fpath, md5=None):
     return check_md5(fpath, md5)
 
 
-def download_url(url, root, filename=None, md5=None):
+def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
     """Download a file from a url and place it in root.
 
     Args:
@@ -70,7 +71,7 @@ def download_url(url, root, filename=None, md5=None):
                 url, fpath,
                 reporthook=gen_bar_updater()
             )
-        except (urllib.error.URLError, IOError) as e:
+        except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
             if url[:5] == 'https':
                 url = url.replace('https:', 'http:')
                 print('Failed download. Trying https -> http instead.'
@@ -86,7 +87,7 @@ def download_url(url, root, filename=None, md5=None):
         raise RuntimeError("File not found or corrupted.")
 
 
-def list_dir(root, prefix=False):
+def list_dir(root: str, prefix: bool = False) -> List[str]:
     """List all directories at a given root
 
     Args:
@@ -101,7 +102,7 @@ def list_dir(root, prefix=False):
     return directories
 
 
-def list_files(root, suffix, prefix=False):
+def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     """List all files ending with a suffix at a given root
 
     Args:
@@ -118,11 +119,11 @@ def list_files(root, suffix, prefix=False):
    return files
 
 
-def _quota_exceeded(response: "requests.models.Response") -> bool:
+def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
     return "Google Drive - Quota exceeded" in response.text
 
 
-def download_file_from_google_drive(file_id, root, filename=None, md5=None):
+def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
     """Download a Google Drive file from and place it in root.
 
     Args:
@@ -165,7 +166,7 @@ def download_file_from_google_drive(file_id, root, filename=None, md5=None):
     _save_response_content(response, fpath)
 
 
-def _get_confirm_token(response):
+def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
     for key, value in response.cookies.items():
         if key.startswith('download_warning'):
             return value
@@ -173,7 +174,9 @@ def _get_confirm_token(response):
     return None
 
 
-def _save_response_content(response, destination, chunk_size=32768):
+def _save_response_content(
+    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
+) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
         progress = 0
@@ -185,31 +188,31 @@ def _save_response_content(response, destination, chunk_size=32768):
     pbar.close()
 
 
-def _is_tarxz(filename):
+def _is_tarxz(filename: str) -> bool:
     return filename.endswith(".tar.xz")
 
 
-def _is_tar(filename):
+def _is_tar(filename: str) -> bool:
     return filename.endswith(".tar")
 
 
-def _is_targz(filename):
+def _is_targz(filename: str) -> bool:
     return filename.endswith(".tar.gz")
 
 
-def _is_tgz(filename):
+def _is_tgz(filename: str) -> bool:
     return filename.endswith(".tgz")
 
 
-def _is_gzip(filename):
+def _is_gzip(filename: str) -> bool:
     return filename.endswith(".gz") and not filename.endswith(".tar.gz")
 
 
-def _is_zip(filename):
+def _is_zip(filename: str) -> bool:
     return filename.endswith(".zip")
 
 
-def extract_archive(from_path, to_path=None, remove_finished=False):
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
     if to_path is None:
         to_path = os.path.dirname(from_path)
@@ -236,8 +239,14 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
         os.remove(from_path)
 
 
-def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
-                                 md5=None, remove_finished=False):
+def download_and_extract_archive(
+    url: str,
+    download_root: str,
+    extract_root: Optional[str] = None,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
     download_root = os.path.expanduser(download_root)
     if extract_root is None:
         extract_root = download_root
@@ -251,11 +260,16 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
     extract_archive(archive, extract_root, remove_finished)
 
 
-def iterable_to_str(iterable):
+def iterable_to_str(iterable: Iterable) -> str:
     return "'" + "', '".join([str(item) for item in iterable]) + "'"
 
 
-def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+    value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None,
+) -> T:
     if not isinstance(value, torch._six.string_classes):
         if arg is None:
             msg = "Expected type str, but got type {type}."
......
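
The new T = TypeVar("T", str, bytes) constrains verify_str_arg so that the return type follows the argument type: str in, str out. The string-literal "requests.models.Response" annotations are also deliberate: requests is imported lazily inside the download helpers, so the name does not exist at module scope and each use needs the # type: ignore[name-defined]. A usage sketch of the typed verify_str_arg (the values are illustrative):

    from torchvision.datasets.utils import verify_str_arg

    # T binds to str at this call site, so mypy infers split: str;
    # a value outside valid_values raises ValueError at runtime.
    split = verify_str_arg("train", arg="split", valid_values=("train", "val"))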