Unverified Commit 9d40302d authored by Tomás Osório, committed by GitHub

Fix inline typing for mypy (#544)

* fix inline typing for mypy

* fix flake8

* change check position

* fix for py3.5

* fix for py3.5

* change to inline typing

* add inline typing
parent c53ceb84
......@@ -168,13 +168,13 @@ def save_encinfo(filepath: str,
     # sox stores the sample rate as a float, though practically sample rates are almost always integers
     # convert integers to floats
     if signalinfo:
-        if not isinstance(signalinfo.rate, float):
+        if signalinfo.rate and not isinstance(signalinfo.rate, float):
             if float(signalinfo.rate) == signalinfo.rate:
                 signalinfo.rate = float(signalinfo.rate)
             else:
                 raise TypeError('Sample rate should be a float or int')
         # check if the bit precision (i.e. bits per sample) is an integer
-        if not isinstance(signalinfo.precision, int):
+        if signalinfo.precision and not isinstance(signalinfo.precision, int):
             if int(signalinfo.precision) == signalinfo.precision:
                 signalinfo.precision = int(signalinfo.precision)
             else:
......
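Note on the hunk above: the added truthiness check keeps both mypy and the runtime from tripping over a `None` value, since `float(None)` / `int(None)` raise `TypeError`. A minimal sketch of the same guard-then-coerce pattern, using a hypothetical `Info` class rather than the real sox signal-info object:

```python
from typing import Optional


class Info:
    """Hypothetical stand-in for a sox signal-info object."""
    def __init__(self, rate: Optional[float] = None) -> None:
        self.rate = rate


def coerce_rate(info: Info) -> None:
    # Only coerce when a rate is actually set; skipping None avoids
    # float(None) at runtime and the Optional complaint from mypy.
    if info.rate and not isinstance(info.rate, float):
        if float(info.rate) == info.rate:
            info.rate = float(info.rate)
        else:
            raise TypeError('Sample rate should be a float or int')


coerce_rate(Info(44100))  # int 44100 is coerced to 44100.0
coerce_rate(Info(None))   # no rate set, nothing to do
```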
 import os.path
-from typing import Any, Optional, Tuple, Union
+from typing import Optional, Tuple
 import torch
 from torch import Tensor
......@@ -9,8 +9,8 @@ from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
 def load(filepath: str,
          out: Optional[Tensor] = None,
-         normalization: Optional[bool] = True,
-         channels_first: Optional[bool] = True,
+         normalization: bool = True,
+         channels_first: bool = True,
          num_frames: int = 0,
          offset: int = 0,
          signalinfo: SignalInfo = None,
......
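On the `Optional[bool] = True` → `bool = True` change above: under mypy, an `Optional[bool]` parameter means callers may pass `None` and the body must handle it; since these flags default to `True` and are only ever used as booleans, the plain `bool` annotation is the more accurate type. A small illustrative sketch with a hypothetical `load_stub` function (not the real backend API):

```python
def load_stub(filepath: str,
              normalization: bool = True,
              channels_first: bool = True,
              num_frames: int = 0) -> str:
    # With plain `bool`, mypy knows these flags can never be None, so no
    # `is not None` guards are needed before branching on them.
    layout = "channels first" if channels_first else "channels last"
    return "{}: {}, normalized={}".format(filepath, layout, normalization)


print(load_stub("example.wav"))
print(load_stub("example.wav", channels_first=False, normalization=False))
```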
......@@ -7,17 +7,18 @@ import sys
 import tarfile
 import threading
 import zipfile
-from io import TextIOWrapper
+from _io import TextIOWrapper
 from queue import Queue
 from typing import Any, Iterable, List, Optional, Tuple, Union
 import torch
 import urllib
+import urllib.request
 from torch.utils.data import Dataset
 from torch.utils.model_zoo import tqdm
-def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> str:
+def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> Any:
     r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
     Borrowed and slightly modified from the Python docs:
     https://docs.python.org/2/library/csv.html#csv-examples
......@@ -63,7 +64,7 @@ def makedir_exist_ok(dirpath: str) -> None:
 def stream_url(url: str,
                start_byte: Optional[int] = None,
                block_size: int = 32 * 1024,
-               progress_bar: bool = True) -> None:
+               progress_bar: bool = True) -> Iterable:
     """Stream url by chunk
     Args:
......@@ -126,10 +127,10 @@ def download_url(url: str,
     # Detect filename
     filename = filename or req_info.get_filename() or os.path.basename(url)
     filepath = os.path.join(download_folder, filename)
     if resume and os.path.exists(filepath):
         mode = "ab"
-        local_size = os.path.getsize(filepath)
+        local_size: Optional[int] = os.path.getsize(filepath)
     elif not resume and os.path.exists(filepath):
         raise RuntimeError(
             "{} already exists. Delete the file manually and retry.".format(filepath)
......@@ -215,7 +216,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
         with tarfile.open(from_path, "r") as tar:
             logging.info("Opened tar file {}.".format(from_path))
             files = []
-            for file_ in tar:
+            for file_ in tar:  # type: Any
                 file_path = os.path.join(to_path, file_.name)
                 if file_.isfile():
                     files.append(file_path)
......@@ -249,7 +250,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
 def walk_files(root: str,
                suffix: Union[str, Tuple[str]],
                prefix: bool = False,
-               remove_suffix: bool = False) -> str:
+               remove_suffix: bool = False) -> Iterable[str]:
     """List recursively all files ending with a suffix at a given root
     Args:
         root (str): Path to directory whose folders need to be listed
......@@ -286,7 +287,7 @@ class _DiskCache(Dataset):
         self.location = location
         self._id = id(self)
-        self._cache = [None] * len(dataset)
+        self._cache: List = [None] * len(dataset)

     def __getitem__(self, n: int) -> Any:
         if self._cache[n]:
......@@ -325,7 +326,7 @@ class _ThreadedIterator(threading.Thread):
     def __init__(self, generator: Iterable, maxsize: int) -> None:
         threading.Thread.__init__(self)
-        self.queue = Queue(maxsize)
+        self.queue: Queue = Queue(maxsize)
         self.generator = generator
         self.daemon = True
         self.start()
......
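The dataset-utility hunks above mostly add PEP 526 inline variable annotations (`local_size: Optional[int]`, `self._cache: List`, `self.queue: Queue`) and give generator-style helpers an `Iterable` return type. A minimal sketch of both patterns, using hypothetical names rather than the real helpers:

```python
import os
from queue import Queue
from typing import Iterable, List, Optional


def walk_txt_files(root: str) -> Iterable[str]:
    # A generator function can be annotated as returning Iterable[str]
    # (or Iterator[str]); mypy then checks the type of every yield.
    for dirpath, _, filenames in os.walk(root):
        for name in filenames:
            if name.endswith(".txt"):
                yield os.path.join(dirpath, name)


class CacheStub:
    def __init__(self, n: int) -> None:
        # Without the annotation, mypy infers List[None] from the first
        # assignment and then rejects storing real items later on.
        self._cache: List = [None] * n
        self.queue: Queue = Queue(maxsize=n)


def resume_size(filepath: str, resume: bool) -> Optional[int]:
    # Annotating the variable up front lets it legally hold None in the
    # branch where the file is not being resumed.
    local_size: Optional[int] = None
    if resume and os.path.exists(filepath):
        local_size = os.path.getsize(filepath)
    return local_size


print(list(walk_txt_files(".")))
print(resume_size("missing.bin", resume=True))
```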
......@@ -76,7 +76,7 @@ class VCTK(Dataset):
         if downsample:
             warnings.warn(
                 "In the next version, transforms will not be part of the dataset. "
-                "Please use `downsample=False` to enable this behavior now, ",
+                "Please use `downsample=False` to enable this behavior now, "
                 "and suppress this warning."
             )
......
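The one-character change in the hunk above fixes more than style: with the trailing comma, the last string literal became a second positional argument to `warnings.warn()` (its `category` parameter, which must be a `Warning` subclass), so the call would fail at runtime instead of emitting one concatenated message. A small sketch of the difference:

```python
import warnings

# Fixed form: implicit string concatenation builds a single message.
warnings.warn(
    "In the next version, transforms will not be part of the dataset. "
    "Please use `downsample=False` to enable this behavior now, "
    "and suppress this warning."
)

# Broken form: the comma turns the last literal into the `category`
# argument, which must be a Warning subclass, so this raises TypeError.
try:
    warnings.warn(
        "Please use `downsample=False` to enable this behavior now, ",
        "and suppress this warning."
    )
except TypeError as err:
    print("TypeError:", err)
```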
......@@ -86,8 +86,8 @@ class SoxEffectsChain(object):
                  out_siginfo: Any = None,
                  out_encinfo: Any = None,
                  filetype: str = "raw") -> None:
-        self.input_file = None
-        self.chain = []
+        self.input_file: Optional[str] = None
+        self.chain: List[str] = []
         self.MAX_EFFECT_OPTS = 20
         self.out_siginfo = out_siginfo
         self.out_encinfo = out_encinfo
......@@ -100,12 +100,12 @@ class SoxEffectsChain(object):
     def append_effect_to_chain(self,
                                ename: str,
-                               eargs: Optional[List[str]] = None) -> None:
+                               eargs: Optional[Union[List[str], str]] = None) -> None:
         r"""Append effect to a sox effects chain.

         Args:
             ename (str): which is the name of effect
-            eargs (List[str], optional): which is a list of effect options. (Default: ``None``)
+            eargs (List[str] or str, optional): which is a list of effect options. (Default: ``None``)
         """
         e = SoxEffect()
         # check if we have a valid effect
......
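Widening `eargs` to `Optional[Union[List[str], str]]` above documents that the effect chain accepts either a single option string or a list of option strings. A minimal sketch of normalizing such an argument, with a hypothetical helper that is not part of the torchaudio API:

```python
from typing import List, Optional, Union


def normalize_effect_args(eargs: Optional[Union[List[str], str]] = None) -> List[str]:
    # Accept None, one option string, or a list of option strings, and
    # always return a List[str] so downstream code has a single shape.
    if eargs is None:
        return []
    if isinstance(eargs, str):
        return [eargs]
    return list(eargs)


print(normalize_effect_args())                # []
print(normalize_effect_args("-3db"))          # ['-3db']
print(normalize_effect_args(["0.5", "2.0"]))  # ['0.5', '2.0']
```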
......@@ -454,7 +454,7 @@ class MFCC(torch.nn.Module):
         super(MFCC, self).__init__()
         supported_dct_types = [2]
         if dct_type not in supported_dct_types:
-            raise ValueError('DCT type not supported'.format(dct_type))
+            raise ValueError('DCT type not supported: {}'.format(dct_type))
         self.sample_rate = sample_rate
         self.n_mfcc = n_mfcc
         self.dct_type = dct_type
......
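The MFCC hunk above adds the missing `{}` placeholder; without one, `str.format()` returns the template unchanged and the offending `dct_type` never appears in the error message. A quick sketch:

```python
dct_type = 3
supported_dct_types = [2]

# Without a replacement field, the argument is silently dropped.
assert 'DCT type not supported'.format(dct_type) == 'DCT type not supported'

try:
    if dct_type not in supported_dct_types:
        raise ValueError('DCT type not supported: {}'.format(dct_type))
except ValueError as err:
    print(err)  # -> DCT type not supported: 3
```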