"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c11d11d63dcfc5caff40d50b8becaf303e5d9ab2"
Unverified Commit 9d40302d authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Fix inline typing for mypy (#544)

* fix inline typing for mypy

* fix flake8

* change check position

* fix for py3.5

* fix for py3.5

* change to inline typing

* add inline typing
parent c53ceb84
...@@ -168,13 +168,13 @@ def save_encinfo(filepath: str, ...@@ -168,13 +168,13 @@ def save_encinfo(filepath: str,
# sox stores the sample rate as a float, though practically sample rates are almost always integers # sox stores the sample rate as a float, though practically sample rates are almost always integers
# convert integers to floats # convert integers to floats
if signalinfo: if signalinfo:
if not isinstance(signalinfo.rate, float): if signalinfo.rate and not isinstance(signalinfo.rate, float):
if float(signalinfo.rate) == signalinfo.rate: if float(signalinfo.rate) == signalinfo.rate:
signalinfo.rate = float(signalinfo.rate) signalinfo.rate = float(signalinfo.rate)
else: else:
raise TypeError('Sample rate should be a float or int') raise TypeError('Sample rate should be a float or int')
# check if the bit precision (i.e. bits per sample) is an integer # check if the bit precision (i.e. bits per sample) is an integer
if not isinstance(signalinfo.precision, int): if signalinfo.precision and not isinstance(signalinfo.precision, int):
if int(signalinfo.precision) == signalinfo.precision: if int(signalinfo.precision) == signalinfo.precision:
signalinfo.precision = int(signalinfo.precision) signalinfo.precision = int(signalinfo.precision)
else: else:
......
import os.path import os.path
from typing import Any, Optional, Tuple, Union from typing import Optional, Tuple
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -9,8 +9,8 @@ from torchaudio._soundfile_backend import SignalInfo, EncodingInfo ...@@ -9,8 +9,8 @@ from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
def load(filepath: str, def load(filepath: str,
out: Optional[Tensor] = None, out: Optional[Tensor] = None,
normalization: Optional[bool] = True, normalization: bool = True,
channels_first: Optional[bool] = True, channels_first: bool = True,
num_frames: int = 0, num_frames: int = 0,
offset: int = 0, offset: int = 0,
signalinfo: SignalInfo = None, signalinfo: SignalInfo = None,
......
...@@ -7,17 +7,18 @@ import sys ...@@ -7,17 +7,18 @@ import sys
import tarfile import tarfile
import threading import threading
import zipfile import zipfile
from io import TextIOWrapper from _io import TextIOWrapper
from queue import Queue from queue import Queue
from typing import Any, Iterable, List, Optional, Tuple, Union from typing import Any, Iterable, List, Optional, Tuple, Union
import torch import torch
import urllib import urllib
import urllib.request
from torch.utils.data import Dataset from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm from torch.utils.model_zoo import tqdm
def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> str: def unicode_csv_reader(unicode_csv_data: TextIOWrapper, **kwargs: Any) -> Any:
r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper. r"""Since the standard csv library does not handle unicode in Python 2, we need a wrapper.
Borrowed and slightly modified from the Python docs: Borrowed and slightly modified from the Python docs:
https://docs.python.org/2/library/csv.html#csv-examples https://docs.python.org/2/library/csv.html#csv-examples
...@@ -63,7 +64,7 @@ def makedir_exist_ok(dirpath: str) -> None: ...@@ -63,7 +64,7 @@ def makedir_exist_ok(dirpath: str) -> None:
def stream_url(url: str, def stream_url(url: str,
start_byte: Optional[int] = None, start_byte: Optional[int] = None,
block_size: int = 32 * 1024, block_size: int = 32 * 1024,
progress_bar: bool = True) -> None: progress_bar: bool = True) -> Iterable:
"""Stream url by chunk """Stream url by chunk
Args: Args:
...@@ -126,10 +127,10 @@ def download_url(url: str, ...@@ -126,10 +127,10 @@ def download_url(url: str,
# Detect filename # Detect filename
filename = filename or req_info.get_filename() or os.path.basename(url) filename = filename or req_info.get_filename() or os.path.basename(url)
filepath = os.path.join(download_folder, filename) filepath = os.path.join(download_folder, filename)
if resume and os.path.exists(filepath): if resume and os.path.exists(filepath):
mode = "ab" mode = "ab"
local_size = os.path.getsize(filepath) local_size: Optional[int] = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath): elif not resume and os.path.exists(filepath):
raise RuntimeError( raise RuntimeError(
"{} already exists. Delete the file manually and retry.".format(filepath) "{} already exists. Delete the file manually and retry.".format(filepath)
...@@ -215,7 +216,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo ...@@ -215,7 +216,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
with tarfile.open(from_path, "r") as tar: with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file {}.".format(from_path)) logging.info("Opened tar file {}.".format(from_path))
files = [] files = []
for file_ in tar: for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name) file_path = os.path.join(to_path, file_.name)
if file_.isfile(): if file_.isfile():
files.append(file_path) files.append(file_path)
...@@ -249,7 +250,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo ...@@ -249,7 +250,7 @@ def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bo
def walk_files(root: str, def walk_files(root: str,
suffix: Union[str, Tuple[str]], suffix: Union[str, Tuple[str]],
prefix: bool = False, prefix: bool = False,
remove_suffix: bool = False) -> str: remove_suffix: bool = False) -> Iterable[str]:
"""List recursively all files ending with a suffix at a given root """List recursively all files ending with a suffix at a given root
Args: Args:
root (str): Path to directory whose folders need to be listed root (str): Path to directory whose folders need to be listed
...@@ -286,7 +287,7 @@ class _DiskCache(Dataset): ...@@ -286,7 +287,7 @@ class _DiskCache(Dataset):
self.location = location self.location = location
self._id = id(self) self._id = id(self)
self._cache = [None] * len(dataset) self._cache: List = [None] * len(dataset)
def __getitem__(self, n: int) -> Any: def __getitem__(self, n: int) -> Any:
if self._cache[n]: if self._cache[n]:
...@@ -325,7 +326,7 @@ class _ThreadedIterator(threading.Thread): ...@@ -325,7 +326,7 @@ class _ThreadedIterator(threading.Thread):
def __init__(self, generator: Iterable, maxsize: int) -> None: def __init__(self, generator: Iterable, maxsize: int) -> None:
threading.Thread.__init__(self) threading.Thread.__init__(self)
self.queue = Queue(maxsize) self.queue: Queue = Queue(maxsize)
self.generator = generator self.generator = generator
self.daemon = True self.daemon = True
self.start() self.start()
......
...@@ -76,7 +76,7 @@ class VCTK(Dataset): ...@@ -76,7 +76,7 @@ class VCTK(Dataset):
if downsample: if downsample:
warnings.warn( warnings.warn(
"In the next version, transforms will not be part of the dataset. " "In the next version, transforms will not be part of the dataset. "
"Please use `downsample=False` to enable this behavior now, ", "Please use `downsample=False` to enable this behavior now, "
"and suppress this warning." "and suppress this warning."
) )
......
...@@ -86,8 +86,8 @@ class SoxEffectsChain(object): ...@@ -86,8 +86,8 @@ class SoxEffectsChain(object):
out_siginfo: Any = None, out_siginfo: Any = None,
out_encinfo: Any = None, out_encinfo: Any = None,
filetype: str = "raw") -> None: filetype: str = "raw") -> None:
self.input_file = None self.input_file: Optional[str] = None
self.chain = [] self.chain: List[str] = []
self.MAX_EFFECT_OPTS = 20 self.MAX_EFFECT_OPTS = 20
self.out_siginfo = out_siginfo self.out_siginfo = out_siginfo
self.out_encinfo = out_encinfo self.out_encinfo = out_encinfo
...@@ -100,12 +100,12 @@ class SoxEffectsChain(object): ...@@ -100,12 +100,12 @@ class SoxEffectsChain(object):
def append_effect_to_chain(self, def append_effect_to_chain(self,
ename: str, ename: str,
eargs: Optional[List[str]] = None) -> None: eargs: Optional[Union[List[str], str]] = None) -> None:
r"""Append effect to a sox effects chain. r"""Append effect to a sox effects chain.
Args: Args:
ename (str): which is the name of effect ename (str): which is the name of effect
eargs (List[str], optional): which is a list of effect options. (Default: ``None``) eargs (List[str] or str, optional): which is a list of effect options. (Default: ``None``)
""" """
e = SoxEffect() e = SoxEffect()
# check if we have a valid effect # check if we have a valid effect
......
...@@ -454,7 +454,7 @@ class MFCC(torch.nn.Module): ...@@ -454,7 +454,7 @@ class MFCC(torch.nn.Module):
super(MFCC, self).__init__() super(MFCC, self).__init__()
supported_dct_types = [2] supported_dct_types = [2]
if dct_type not in supported_dct_types: if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported'.format(dct_type)) raise ValueError('DCT type not supported: {}'.format(dct_type))
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_mfcc = n_mfcc self.n_mfcc = n_mfcc
self.dct_type = dct_type self.dct_type = dct_type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment