Unverified Commit 8b029651 authored by Philip Meier, committed by GitHub

[WIP] add typehints to datasets (#2487)

* enable mypy for torchvision.datasets and ignore existing errors

* imagenet

* utils

* lint
parent f5bf9d52
@@ -4,9 +4,9 @@ files = torchvision
 show_error_codes = True
 
 pretty = True
 
-[mypy-torchvision.datasets.*]
+;[mypy-torchvision.datasets.*]
 
-ignore_errors = True
+;ignore_errors = True
 
 [mypy-torchvision.io.*]
@@ -24,7 +24,30 @@ ignore_errors = True
 ignore_errors = True
 
-[mypy-PIL]
+[mypy-PIL.*]
 
 ignore_missing_imports = True
+
+[mypy-numpy.*]
+
+ignore_missing_imports = True
+
+[mypy-scipy.*]
+
+ignore_missing_imports = True
+
+[mypy-pycocotools.*]
+
+ignore_missing_imports = True
+
+[mypy-lmdb.*]
+
+ignore_missing_imports = True
+
+[mypy-pandas.*]
+
+ignore_missing_imports = True
+
+[mypy-accimage.*]
+
+ignore_missing_imports = True
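For context, the new overrides silence mypy's missing-stub errors for third-party packages that (at the time of this commit, as far as I can tell) ship no type information. A minimal sketch of what this unblocks; the function below is illustrative and not part of the diff:

# Illustrative only: with the [mypy-scipy.*] override above, this import
# type-checks even though scipy provides no stubs; sio.loadmat comes back as Any.
import scipy.io as sio

def read_synsets(metafile: str) -> list:
    # The Any result can flow into a typed context without an explicit cast.
    return list(sio.loadmat(metafile, squeeze_me=True)["synsets"])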
@@ -3,6 +3,7 @@ from contextlib import contextmanager
 import os
 import shutil
 import tempfile
+from typing import Any, Dict, List, Iterator, Optional, Tuple
 
 import torch
 
 from .folder import ImageFolder
 from .utils import check_integrity, extract_archive, verify_str_arg
@@ -37,7 +38,7 @@ class ImageNet(ImageFolder):
         targets (list): The class_index value for each image in the dataset
     """
 
-    def __init__(self, root, split='train', download=None, **kwargs):
+    def __init__(self, root: str, split: str = 'train', download: Optional[str] = None, **kwargs: Any) -> None:
         if download is True:
             msg = ("The dataset is no longer publicly accessible. You need to "
                    "download the archives externally and place them in the root "
@@ -64,7 +65,7 @@ class ImageNet(ImageFolder):
                              for idx, clss in enumerate(self.classes)
                              for cls in clss}
 
-    def parse_archives(self):
+    def parse_archives(self) -> None:
         if not check_integrity(os.path.join(self.root, META_FILE)):
             parse_devkit_archive(self.root)
@@ -75,14 +76,14 @@ class ImageNet(ImageFolder):
             parse_val_archive(self.root)
 
     @property
-    def split_folder(self):
+    def split_folder(self) -> str:
         return os.path.join(self.root, self.split)
 
-    def extra_repr(self):
+    def extra_repr(self) -> str:
         return "Split: {split}".format(**self.__dict__)
 
 
-def load_meta_file(root, file=None):
+def load_meta_file(root: str, file: Optional[str] = None) -> Tuple[Dict[str, str], List[str]]:
     if file is None:
         file = META_FILE
     file = os.path.join(root, file)
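Annotating the return value as Tuple[Dict[str, str], List[str]] means both unpacked values in a call like the following get precise types under mypy (the path is again a placeholder):

from torchvision.datasets.imagenet import load_meta_file

# mypy infers wnid_to_classes as Dict[str, str] and val_wnids as List[str].
wnid_to_classes, val_wnids = load_meta_file("/path/to/imagenet")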
@@ -95,14 +96,14 @@ def load_meta_file(root, file=None):
         raise RuntimeError(msg.format(file, root))
 
 
-def _verify_archive(root, file, md5):
+def _verify_archive(root: str, file: str, md5: str) -> None:
     if not check_integrity(os.path.join(root, file), md5):
         msg = ("The archive {} is not present in the root directory or is corrupted. "
                "You need to download it externally and place it in {}.")
         raise RuntimeError(msg.format(file, root))
 
 
-def parse_devkit_archive(root, file=None):
+def parse_devkit_archive(root: str, file: Optional[str] = None) -> None:
     """Parse the devkit archive of the ImageNet2012 classification dataset and save
     the meta information in a binary file.
@@ -113,7 +114,7 @@ def parse_devkit_archive(root, file=None):
     """
     import scipy.io as sio
 
-    def parse_meta_mat(devkit_root):
+    def parse_meta_mat(devkit_root: str) -> Tuple[Dict[int, str], Dict[str, str]]:
         metafile = os.path.join(devkit_root, "data", "meta.mat")
         meta = sio.loadmat(metafile, squeeze_me=True)['synsets']
         nums_children = list(zip(*meta))[4]
@@ -125,7 +126,7 @@ def parse_devkit_archive(root, file=None):
         wnid_to_classes = {wnid: clss for wnid, clss in zip(wnids, classes)}
         return idx_to_wnid, wnid_to_classes
 
-    def parse_val_groundtruth_txt(devkit_root):
+    def parse_val_groundtruth_txt(devkit_root: str) -> List[int]:
         file = os.path.join(devkit_root, "data",
                             "ILSVRC2012_validation_ground_truth.txt")
         with open(file, 'r') as txtfh:
@@ -133,7 +134,7 @@ def parse_devkit_archive(root, file=None):
         return [int(val_idx) for val_idx in val_idcs]
 
     @contextmanager
-    def get_tmp_dir():
+    def get_tmp_dir() -> Iterator[str]:
         tmp_dir = tempfile.mkdtemp()
         try:
             yield tmp_dir
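Note the annotation here: a @contextmanager generator is typed by wrapping its yield type in Iterator[str], not as ContextManager[str]; the decorator performs that conversion. A self-contained sketch mirroring the helper above:

from contextlib import contextmanager
from typing import Iterator
import shutil
import tempfile

@contextmanager
def tmp_dir() -> Iterator[str]:  # annotate what the generator yields
    path = tempfile.mkdtemp()
    try:
        yield path
    finally:
        shutil.rmtree(path)

with tmp_dir() as d:  # d is inferred as str
    print(d)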
@@ -158,7 +159,7 @@ def parse_devkit_archive(root, file=None):
     torch.save((wnid_to_classes, val_wnids), os.path.join(root, META_FILE))
 
 
-def parse_train_archive(root, file=None, folder="train"):
+def parse_train_archive(root: str, file: Optional[str] = None, folder: str = "train") -> None:
     """Parse the train images archive of the ImageNet2012 classification dataset and
     prepare it for usage with the ImageNet dataset.
@@ -184,7 +185,9 @@ def parse_train_archive(root, file=None, folder="train"):
         extract_archive(archive, os.path.splitext(archive)[0], remove_finished=True)
 
 
-def parse_val_archive(root, file=None, wnids=None, folder="val"):
+def parse_val_archive(
+    root: str, file: Optional[str] = None, wnids: Optional[List[str]] = None, folder: str = "val"
+) -> None:
     """Parse the validation images archive of the ImageNet2012 classification dataset
     and prepare it for usage with the ImageNet dataset.
......
@@ -343,7 +343,7 @@ class QMNIST(MNIST):
         'test50k': 'test',
         'nist': 'nist'
     }
-    resources = {
+    resources = {  # type: ignore[assignment]
         'train': [('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-images-idx3-ubyte.gz',
                    'ed72d4157d28c017586c42bc6afe6370'),
                   ('https://raw.githubusercontent.com/facebookresearch/qmnist/master/qmnist-train-labels-idx2-int.gz',
......
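The ignore is needed because QMNIST overrides the resources attribute it inherits from MNIST with a dict, while the base class declares a list of (url, md5) tuples; mypy flags the incompatible override under the assignment error code. A minimal reproduction with hypothetical class names:

from typing import Dict, List, Tuple

class Base:
    resources: List[Tuple[str, str]] = []

class Sub(Base):
    # error: Incompatible types in assignment ... [assignment]
    resources: Dict[str, List[Tuple[str, str]]] = {}  # type: ignore[assignment]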
@@ -4,13 +4,14 @@ import hashlib
 import gzip
 import errno
 import tarfile
+from typing import Any, Callable, List, Iterable, Optional, TypeVar
 import zipfile
 
 import torch
 from torch.utils.model_zoo import tqdm
 
 
-def gen_bar_updater():
+def gen_bar_updater() -> Callable[[int, int, int], None]:
     pbar = tqdm(total=None)
 
     def bar_update(count, block_size, total_size):
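The Callable[[int, int, int], None] return annotation encodes the reporthook protocol of urllib.request.urlretrieve, which invokes the hook with a block count, a block size, and a total size. A small usage sketch:

from torchvision.datasets.utils import gen_bar_updater

hook = gen_bar_updater()   # typed as Callable[[int, int, int], None]
hook(1, 8192, 1_000_000)   # the (count, block_size, total_size) shape urlretrieve passes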
@@ -22,7 +23,7 @@ def gen_bar_updater():
     return bar_update
 
 
-def calculate_md5(fpath, chunk_size=1024 * 1024):
+def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
     md5 = hashlib.md5()
     with open(fpath, 'rb') as f:
         for chunk in iter(lambda: f.read(chunk_size), b''):
@@ -30,11 +31,11 @@ def calculate_md5(fpath, chunk_size=1024 * 1024):
     return md5.hexdigest()
 
 
-def check_md5(fpath, md5, **kwargs):
+def check_md5(fpath: str, md5: str, **kwargs: Any) -> bool:
     return md5 == calculate_md5(fpath, **kwargs)
 
 
-def check_integrity(fpath, md5=None):
+def check_integrity(fpath: str, md5: Optional[str] = None) -> bool:
     if not os.path.isfile(fpath):
         return False
     if md5 is None:
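Taken together, the annotations document the contract of the integrity helpers: md5 is Optional, and omitting it reduces check_integrity to a bare existence check. A sketch with placeholder paths:

from torchvision.datasets.utils import calculate_md5, check_integrity

digest = calculate_md5("/tmp/archive.tar.gz")            # str: hex digest of the file
exists = check_integrity("/tmp/archive.tar.gz")          # bool: only checks the file exists
valid = check_integrity("/tmp/archive.tar.gz", digest)   # bool: also verifies the md5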
@@ -42,7 +43,7 @@ def check_integrity(fpath, md5=None):
     return check_md5(fpath, md5)
 
 
-def download_url(url, root, filename=None, md5=None):
+def download_url(url: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None) -> None:
     """Download a file from a url and place it in root.
 
     Args:
@@ -70,7 +71,7 @@ def download_url(url, root, filename=None, md5=None):
                 url, fpath,
                 reporthook=gen_bar_updater()
             )
-        except (urllib.error.URLError, IOError) as e:
+        except (urllib.error.URLError, IOError) as e:  # type: ignore[attr-defined]
             if url[:5] == 'https':
                 url = url.replace('https:', 'http:')
                 print('Failed download. Trying https -> http instead.'
@@ -86,7 +87,7 @@ def download_url(url, root, filename=None, md5=None):
             raise RuntimeError("File not found or corrupted.")
 
 
-def list_dir(root, prefix=False):
+def list_dir(root: str, prefix: bool = False) -> List[str]:
     """List all directories at a given root
 
     Args:
@@ -101,7 +102,7 @@ def list_dir(root, prefix=False):
     return directories
 
 
-def list_files(root, suffix, prefix=False):
+def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
     """List all files ending with a suffix at a given root
 
     Args:
@@ -118,11 +119,11 @@ def list_files(root, suffix, prefix=False):
     return files
 
 
-def _quota_exceeded(response: "requests.models.Response") -> bool:
+def _quota_exceeded(response: "requests.models.Response") -> bool:  # type: ignore[name-defined]
     return "Google Drive - Quota exceeded" in response.text
 
 
-def download_file_from_google_drive(file_id, root, filename=None, md5=None):
+def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
     """Download a Google Drive file from and place it in root.
 
     Args:
@@ -165,7 +166,7 @@ def download_file_from_google_drive(file_id, root, filename=None, md5=None):
         _save_response_content(response, fpath)
 
 
-def _get_confirm_token(response):
+def _get_confirm_token(response: "requests.models.Response") -> Optional[str]:  # type: ignore[name-defined]
     for key, value in response.cookies.items():
         if key.startswith('download_warning'):
             return value
@@ -173,7 +174,9 @@ def _get_confirm_token(response):
     return None
 
 
-def _save_response_content(response, destination, chunk_size=32768):
+def _save_response_content(
+    response: "requests.models.Response", destination: str, chunk_size: int = 32768,  # type: ignore[name-defined]
+) -> None:
     with open(destination, "wb") as f:
         pbar = tqdm(total=None)
         progress = 0
@@ -185,31 +188,31 @@ def _save_response_content(response, destination, chunk_size=32768):
         pbar.close()
 
 
-def _is_tarxz(filename):
+def _is_tarxz(filename: str) -> bool:
     return filename.endswith(".tar.xz")
 
 
-def _is_tar(filename):
+def _is_tar(filename: str) -> bool:
     return filename.endswith(".tar")
 
 
-def _is_targz(filename):
+def _is_targz(filename: str) -> bool:
     return filename.endswith(".tar.gz")
 
 
-def _is_tgz(filename):
+def _is_tgz(filename: str) -> bool:
     return filename.endswith(".tgz")
 
 
-def _is_gzip(filename):
+def _is_gzip(filename: str) -> bool:
     return filename.endswith(".gz") and not filename.endswith(".tar.gz")
 
 
-def _is_zip(filename):
+def _is_zip(filename: str) -> bool:
     return filename.endswith(".zip")
 
 
-def extract_archive(from_path, to_path=None, remove_finished=False):
+def extract_archive(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> None:
     if to_path is None:
         to_path = os.path.dirname(from_path)
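extract_archive dispatches on the suffix predicates above, and the Optional to_path defaults to the directory containing the archive. A sketch with a placeholder path:

from torchvision.datasets.utils import extract_archive

# Extracts /tmp/foo.tar.gz into /tmp, since to_path defaults to the archive's directory.
extract_archive("/tmp/foo.tar.gz", remove_finished=False)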
@@ -236,8 +239,14 @@ def extract_archive(from_path, to_path=None, remove_finished=False):
         os.remove(from_path)
 
 
-def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
-                                 md5=None, remove_finished=False):
+def download_and_extract_archive(
+    url: str,
+    download_root: str,
+    extract_root: Optional[str] = None,
+    filename: Optional[str] = None,
+    md5: Optional[str] = None,
+    remove_finished: bool = False,
+) -> None:
     download_root = os.path.expanduser(download_root)
     if extract_root is None:
         extract_root = download_root
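With the signature reformatted onto one line per parameter, keyword usage reads naturally; the URL and paths below are placeholders:

from torchvision.datasets.utils import download_and_extract_archive

download_and_extract_archive(
    url="https://example.com/dataset.tar.gz",  # placeholder URL
    download_root="/tmp/downloads",
    extract_root="/tmp/data",
    remove_finished=True,  # delete the downloaded archive after extraction
)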
@@ -251,11 +260,16 @@ def download_and_extract_archive(url, download_root, extract_root=None, filename
     extract_archive(archive, extract_root, remove_finished)
 
 
-def iterable_to_str(iterable):
+def iterable_to_str(iterable: Iterable) -> str:
     return "'" + "', '".join([str(item) for item in iterable]) + "'"
 
 
-def verify_str_arg(value, arg=None, valid_values=None, custom_msg=None):
+T = TypeVar("T", str, bytes)
+
+
+def verify_str_arg(
+    value: T, arg: Optional[str] = None, valid_values: Iterable[T] = None, custom_msg: Optional[str] = None,
+) -> T:
     if not isinstance(value, torch._six.string_classes):
         if arg is None:
             msg = "Expected type str, but got type {type}."