Unverified Commit 4e5ee9f1 authored by Tomás Osório, committed by GitHub

Inline typing utils dataset (#522)



* add inline typing to utils Dataset

* add inline typing to common_utils

* add missing inline typing

* add typing to kwarg

* add missing inline typing

* update docstring

* undo indentation
Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent a9c4d0a8
@@ -4,7 +4,7 @@ PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)
def _check_module_exists(name):
def _check_module_exists(name: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
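
The function body and the rest of the docstring are collapsed in this view. As a hedged sketch only (the actual implementation is not shown here), such a check is commonly written with importlib.util.find_spec, which locates a top-level module without importing it:

>>> import importlib.util
>>> def check_module_exists(name: str) -> bool:   # illustrative stand-in, not the real body
...     # find_spec returns None when no top-level module of that name can be found
...     return importlib.util.find_spec(name) is not None
>>> check_module_exists("csv")
True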
......
@@ -7,7 +7,9 @@ import sys
import tarfile
import threading
import zipfile
from io import TextIOWrapper
from queue import Queue
from typing import Any, Iterable, List, Optional, Tuple, Union
import torch
import urllib
@@ -15,12 +17,13 @@ from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
def unicode_csv_reader(unicode_csv_data, **kwargs):
def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> str:
r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
Borrowed and slightly modified from the Python docs:
https://docs.python.org/2/library/csv.html#csv-examples
Arguments:
unicode_csv_data: unicode csv data (see example below)
Args:
unicode_csv_data (TextIOWrapper): unicode csv data (see example below)
Examples:
>>> from torchaudio.datasets.utils import unicode_csv_reader
>>> import io
@@ -44,7 +47,7 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
yield line
def makedir_exist_ok(dirpath):
def makedir_exist_ok(dirpath: str) -> None:
"""
Python2 support for os.makedirs(.., exist_ok=True)
"""
@@ -57,14 +60,17 @@ def makedir_exist_ok(dirpath):
raise
def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
def stream_url(url: str,
start_byte: Optional[int] = None,
block_size: int = 32 * 1024,
progress_bar: bool = True) -> None:
"""Stream url by chunk
Args:
url (str): Url.
start_byte (Optional[int]): Start streaming at that point.
block_size (int): Size of chunks to stream.
progress_bar (bool): Display a progress bar.
start_byte (int, optional): Start streaming at that point (Default: ``None``).
block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
"""
# If we already have the whole file, there is no need to download it again
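
The return annotation above says None, but the chunked download and the pbar.update(len(chunk)) context below suggest the function yields byte chunks as it streams. Treating that as an assumption, a minimal usage sketch (the URL and output path are placeholders):

>>> from torchaudio.datasets.utils import stream_url
>>> url = "https://example.com/archive.tar.gz"        # hypothetical URL
>>> with open("archive.tar.gz", "wb") as f:
...     for chunk in stream_url(url, progress_bar=False):
...         f.write(chunk)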
@@ -95,25 +101,23 @@ def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
pbar.update(len(chunk))
def download_url(
url,
download_folder,
filename=None,
hash_value=None,
hash_type="sha256",
progress_bar=True,
resume=False,
):
def download_url(url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False) -> None:
"""Download file to disk.
Args:
url (str): Url.
download_folder (str): Folder to download file.
filename (str): Name of downloaded file. If None, it is inferred from the url.
hash_value (str): Hash for url.
hash_type (str): Hash type, among "sha256" and "md5".
progress_bar (bool): Display a progress bar.
resume (bool): Enable resuming download.
filename (str, optional): Name of downloaded file. If None, it is inferred from the url (Default: ``None``).
hash_value (str, optional): Hash for url (Default: ``None``).
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
resume (bool, optional): Enable resuming download (Default: ``False``).
"""
req = urllib.request.Request(url, method="HEAD")
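
A hedged usage sketch of download_url under the signature shown above; the URL and hash value are placeholders, not real data:

>>> from torchaudio.datasets.utils import download_url
>>> url = "https://example.com/archive.tar.gz"        # hypothetical URL
>>> download_url(url, "./data",
...              hash_value="0123abcd...",             # hypothetical sha256 digest
...              hash_type="sha256",
...              resume=True)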
@@ -157,13 +161,16 @@ def download_url(
)
def validate_file(file_obj, hash_value, hash_type="sha256"):
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
"""Validate a given file object with its hash.
Args:
file_obj: File object to read from.
hash_value (str): Hash for url.
hash_type (str): Hash type, among "sha256" and "md5".
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
Returns:
bool: returns True if it is a valid file, else False.
"""
if hash_type == "sha256":
@@ -183,14 +190,16 @@ def validate_file(file_obj, hash_value, hash_type="sha256"):
return hash_func.hexdigest() == hash_value
def extract_archive(from_path, to_path=None, overwrite=False):
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
"""Extract archive.
Arguments:
from_path: the path of the archive.
to_path: the root path of the extracted files (directory of from_path)
overwrite: overwrite existing files (False)
Args:
from_path (str): the path of the archive.
to_path (str, optional): the root path of the extracted files (directory of from_path) (Default: ``None``)
overwrite (bool, optional): overwrite existing files (Default: ``False``)
Returns:
List of paths to extracted files even if not overwritten.
list: List of paths to extracted files even if not overwritten.
Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> from_path = './validation.tar.gz'
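
The docstring example above is truncated by the hunk boundary. A hedged completion that keeps the URL and path it already introduces (the download step and the variable names added here are assumptions, not the original docstring text):

>>> from torchaudio.datasets.utils import download_url, extract_archive
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> from_path = './validation.tar.gz'
>>> download_url(url, './')                 # filename inferred from the url
>>> extracted = extract_archive(from_path)  # returns the list of extracted file paths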
@@ -237,7 +246,10 @@ def extract_archive(from_path, to_path=None, overwrite=False):
raise NotImplementedError("We currently only support tar.gz, tgz, and zip archives.")
def walk_files(root, suffix, prefix=False, remove_suffix=False):
def walk_files(root: str,
suffix: Union[str, Tuple[str]],
prefix: bool = False,
remove_suffix: bool = False) -> str:
"""List recursively all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
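
The walk_files docstring is cut off here. A hedged usage sketch, assuming the function yields matching file names one at a time (the directory is hypothetical):

>>> from torchaudio.datasets.utils import walk_files
>>> for fname in walk_files("./LibriSpeech", suffix=".flac"):   # hypothetical root
...     print(fname)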
@@ -269,14 +281,14 @@ class _DiskCache(Dataset):
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
"""
def __init__(self, dataset, location=".cached"):
def __init__(self, dataset: Dataset, location: str = ".cached") -> None:
self.dataset = dataset
self.location = location
self._id = id(self)
self._cache = [None] * len(dataset)
def __getitem__(self, n):
def __getitem__(self, n: int) -> Any:
if self._cache[n]:
f = self._cache[n]
return torch.load(f)
@@ -291,11 +303,11 @@ class _DiskCache(Dataset):
return item
def __len__(self):
def __len__(self) -> int:
return len(self.dataset)
def diskcache_iterator(dataset, location=".cached"):
def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset:
return _DiskCache(dataset, location)
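
A hedged sketch of diskcache_iterator, which wraps a map-style Dataset so each item is written to disk the first time it is fetched (the dataset below is hypothetical):

>>> from torchaudio.datasets.utils import diskcache_iterator
>>> dataset = MyAudioDataset()                      # hypothetical Dataset with __getitem__/__len__
>>> cached = diskcache_iterator(dataset, location=".cached")
>>> item = cached[0]                                # first access computes the item and caches it
>>> item = cached[0]                                # later accesses load the cached copy with torch.load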
@@ -311,31 +323,31 @@ class _ThreadedIterator(threading.Thread):
class _End:
pass
def __init__(self, generator, maxsize):
def __init__(self, generator: Iterable, maxsize: int) -> None:
threading.Thread.__init__(self)
self.queue = Queue(maxsize)
self.generator = generator
self.daemon = True
self.start()
def run(self):
def run(self) -> None:
for item in self.generator:
self.queue.put(item)
self.queue.put(self._End)
def __iter__(self):
def __iter__(self) -> Any:
return self
def __next__(self):
def __next__(self) -> Any:
next_item = self.queue.get()
if next_item == self._End:
raise StopIteration
return next_item
# Required for Python 2.7 compatibility
def next(self):
def next(self) -> Any:
return self.__next__()
def bg_iterator(iterable, maxsize):
def bg_iterator(iterable: Iterable, maxsize: int) -> Any:
return _ThreadedIterator(iterable, maxsize=maxsize)
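
A hedged sketch of bg_iterator, which runs any iterable in a background thread (_ThreadedIterator above) and prefetches up to maxsize items into a queue:

>>> from torchaudio.datasets.utils import bg_iterator
>>> def slow_source():                              # hypothetical generator
...     for n in range(3):
...         yield n
>>> for item in bg_iterator(slow_source(), maxsize=2):
...     print(item)
0
1
2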