Unverified Commit 4e5ee9f1 authored by Tomás Osório, committed by GitHub

Inline typing utils dataset (#522)



* add inline typing to utils Dataset

* add inline typing to common_utils

* add missing inline typing

* add typing to kwarg

* add missing inline typing

* update docstring

* undo indentation
Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent a9c4d0a8
@@ -4,7 +4,7 @@ PY3 = sys.version_info > (3, 0)
 PY34 = sys.version_info >= (3, 4)
 
-def _check_module_exists(name):
+def _check_module_exists(name: str) -> bool:
     r"""Returns if a top-level module with :attr:`name` exists *without*
     importing it. This is generally safer than a try-except block around
     `import X`. It avoids third party libraries breaking assumptions of some of
     ...
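
A side note on this hunk: the newly annotated helper returns a bool, so it can guard optional dependencies without importing them. A minimal sketch, assuming `_check_module_exists` is in scope in the test utilities (the module name is a placeholder):

    # Skip work that needs an optional dependency, without importing it.
    if not _check_module_exists("scipy"):
        print("scipy not available; skipping scipy-dependent tests")
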
@@ -7,7 +7,9 @@ import sys
 import tarfile
 import threading
 import zipfile
+from io import TextIOWrapper
 from queue import Queue
+from typing import Any, Iterable, List, Optional, Tuple, Union
 
 import torch
 import urllib
@@ -15,12 +17,13 @@ from torch.utils.data import Dataset
 from torch.utils.model_zoo import tqdm
 
-def unicode_csv_reader(unicode_csv_data, **kwargs):
+def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> str:
     r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
     Borrowed and slightly modified from the Python docs:
     https://docs.python.org/2/library/csv.html#csv-examples
 
-    Arguments:
-        unicode_csv_data: unicode csv data (see example below)
+    Args:
+        unicode_csv_data (TextIOWrapper): unicode csv data (see example below)
 
     Examples:
         >>> from torchaudio.datasets.utils import unicode_csv_reader
         >>> import io
@@ -44,7 +47,7 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
         yield line
 
-def makedir_exist_ok(dirpath):
+def makedir_exist_ok(dirpath: str) -> None:
     """
     Python2 support for os.makedirs(.., exist_ok=True)
     """
@@ -57,14 +60,17 @@ def makedir_exist_ok(dirpath):
         raise
 
-def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
+def stream_url(url: str,
+               start_byte: Optional[int] = None,
+               block_size: int = 32 * 1024,
+               progress_bar: bool = True) -> None:
     """Stream url by chunk
 
     Args:
         url (str): Url.
-        start_byte (Optional[int]): Start streaming at that point.
-        block_size (int): Size of chunks to stream.
-        progress_bar (bool): Display a progress bar.
+        start_byte (int, optional): Start streaming at that point (Default: ``None``).
+        block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
+        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
     """
 
     # If we already have the whole file, there is no need to download it again
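
Judging from the `pbar.update(len(chunk))` context line in the next hunk, `stream_url` is consumed chunk by chunk. A hedged sketch, assuming it yields byte chunks (the URL is a placeholder):

    from torchaudio.datasets.utils import stream_url

    data = b""
    for chunk in stream_url("https://example.com/archive.tar.gz", progress_bar=False):
        data += chunk  # accumulate 32 KiB blocks as they arrive
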
@@ -95,25 +101,23 @@ def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
             pbar.update(len(chunk))
 
-def download_url(
-    url,
-    download_folder,
-    filename=None,
-    hash_value=None,
-    hash_type="sha256",
-    progress_bar=True,
-    resume=False,
-):
+def download_url(url: str,
+                 download_folder: str,
+                 filename: Optional[str] = None,
+                 hash_value: Optional[str] = None,
+                 hash_type: str = "sha256",
+                 progress_bar: bool = True,
+                 resume: bool = False) -> None:
     """Download file to disk.
 
     Args:
         url (str): Url.
         download_folder (str): Folder to download file.
-        filename (str): Name of downloaded file. If None, it is inferred from the url.
-        hash_value (str): Hash for url.
-        hash_type (str): Hash type, among "sha256" and "md5".
-        progress_bar (bool): Display a progress bar.
-        resume (bool): Enable resuming download.
+        filename (str, optional): Name of downloaded file. If None, it is inferred from the url (Default: ``None``).
+        hash_value (str, optional): Hash for url (Default: ``None``).
+        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
+        progress_bar (bool, optional): Display a progress bar (Default: ``True``).
+        resume (bool, optional): Enable resuming download (Default: ``False``).
     """
 
     req = urllib.request.Request(url, method="HEAD")
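
The reflowed signature reads naturally at a call site. A minimal sketch using the documented defaults (URL and checksum are placeholders):

    from torchaudio.datasets.utils import download_url

    # Download into ./data, infer the filename from the URL,
    # verify a sha256 checksum, and resume any partial download.
    download_url("https://example.com/dataset.tar.gz",
                 "./data",
                 hash_value="0123abcd...",
                 resume=True)
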
@@ -157,13 +161,16 @@ def download_url(
     )
 
-def validate_file(file_obj, hash_value, hash_type="sha256"):
+def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
     """Validate a given file object with its hash.
 
     Args:
         file_obj: File object to read from.
         hash_value (str): Hash for url.
-        hash_type (str): Hash type, among "sha256" and "md5".
+        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
+
+    Returns:
+        bool: True if it is a valid file, else False.
     """
 
     if hash_type == "sha256":
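
The new Returns section makes the boolean contract explicit. A minimal sketch (path and checksum are placeholders):

    from torchaudio.datasets.utils import validate_file

    with open("./data/dataset.tar.gz", "rb") as file_obj:
        if not validate_file(file_obj, "0123abcd...", hash_type="sha256"):
            raise RuntimeError("checksum mismatch; the download may be corrupted")
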
@@ -183,14 +190,16 @@ def validate_file(file_obj, hash_value, hash_type="sha256"):
     return hash_func.hexdigest() == hash_value
 
-def extract_archive(from_path, to_path=None, overwrite=False):
+def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
     """Extract archive.
 
-    Arguments:
-        from_path: the path of the archive.
-        to_path: the root path of the extracted files (directory of from_path)
-        overwrite: overwrite existing files (False)
+    Args:
+        from_path (str): the path of the archive.
+        to_path (str, optional): the root path of the extracted files (directory of from_path) (Default: ``None``)
+        overwrite (bool, optional): overwrite existing files (Default: ``False``)
 
     Returns:
-        List of paths to extracted files even if not overwritten.
+        list: List of paths to extracted files even if not overwritten.
 
     Examples:
         >>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
         >>> from_path = './validation.tar.gz'
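
With the `List[str]` annotation, callers can iterate the returned paths directly; a short continuation of the docstring example (the destination is a placeholder):

    files = extract_archive('./validation.tar.gz', to_path='./data')
    for path in files:
        print(path)  # listed even when existing files were not overwritten
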
@@ -237,7 +246,10 @@ def extract_archive(from_path, to_path=None, overwrite=False):
     raise NotImplementedError("We currently only support tar.gz, tgz, and zip archives.")
 
-def walk_files(root, suffix, prefix=False, remove_suffix=False):
+def walk_files(root: str,
+               suffix: Union[str, Tuple[str]],
+               prefix: bool = False,
+               remove_suffix: bool = False) -> str:
     """List recursively all files ending with a suffix at a given root
 
     Args:
         root (str): Path to directory whose folders need to be listed
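
`walk_files` is consumed as an iterator of matching file names (the `str` annotation describes each yielded item). A hedged sketch, assuming `prefix` prepends the path under `root` and `remove_suffix` strips the suffix from each name (the root is a placeholder):

    from torchaudio.datasets.utils import walk_files

    for fileid in walk_files("./data", suffix=".wav", prefix=True, remove_suffix=True):
        print(fileid)
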
@@ -269,14 +281,14 @@ class _DiskCache(Dataset):
     Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
     """
 
-    def __init__(self, dataset, location=".cached"):
+    def __init__(self, dataset: Dataset, location: str = ".cached") -> None:
         self.dataset = dataset
         self.location = location
 
         self._id = id(self)
         self._cache = [None] * len(dataset)
 
-    def __getitem__(self, n):
+    def __getitem__(self, n: int) -> Any:
         if self._cache[n]:
             f = self._cache[n]
             return torch.load(f)
@@ -291,11 +303,11 @@ class _DiskCache(Dataset):
         return item
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.dataset)
 
-def diskcache_iterator(dataset, location=".cached"):
+def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset:
     return _DiskCache(dataset, location)
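
`diskcache_iterator` wraps any `Dataset` in the `_DiskCache` shown above: per the class docstring, an item is saved to disk when first returned and re-read with `torch.load` on later accesses. A minimal sketch (`my_dataset` is a placeholder for any `torch.utils.data.Dataset`):

    from torchaudio.datasets.utils import diskcache_iterator

    cached = diskcache_iterator(my_dataset)
    item = cached[0]  # first access computes the item and caches it under ".cached"
    item = cached[0]  # later accesses load the cached copy from disk
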
@@ -311,31 +323,31 @@ class _ThreadedIterator(threading.Thread):
     class _End:
         pass
 
-    def __init__(self, generator, maxsize):
+    def __init__(self, generator: Iterable, maxsize: int) -> None:
         threading.Thread.__init__(self)
         self.queue = Queue(maxsize)
         self.generator = generator
         self.daemon = True
         self.start()
 
-    def run(self):
+    def run(self) -> None:
         for item in self.generator:
             self.queue.put(item)
         self.queue.put(self._End)
 
-    def __iter__(self):
+    def __iter__(self) -> Any:
         return self
 
-    def __next__(self):
+    def __next__(self) -> Any:
         next_item = self.queue.get()
         if next_item == self._End:
             raise StopIteration
         return next_item
 
     # Required for Python 2.7 compatibility
-    def next(self):
+    def next(self) -> Any:
         return self.__next__()
 
-def bg_iterator(iterable, maxsize):
+def bg_iterator(iterable: Iterable, maxsize: int) -> Any:
     return _ThreadedIterator(iterable, maxsize=maxsize)
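
`bg_iterator` runs the wrapped iterable on a daemon thread and hands items over through a bounded `Queue`, translating the `_End` sentinel back into `StopIteration`; producer work such as disk reads can thus overlap with the consumer loop. A minimal sketch (`my_loader` and `train_step` are placeholders):

    from torchaudio.datasets.utils import bg_iterator

    # Prefetch up to 8 items in the background while the loop body runs.
    for batch in bg_iterator(my_loader, maxsize=8):
        train_step(batch)
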