Unverified Commit e9f8ba9d authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Add inline typing to SoxEffects (#490)

* add inline typing

* correct typing and docstring

* remove inline typing Any on SoxEffect
parent c6bca702
import torch from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torchaudio import torchaudio
from torch import Tensor
from torchaudio._backend import _audio_backend_guard from torchaudio._backend import _audio_backend_guard
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def effect_names(): def effect_names() -> List[str]:
"""Gets list of valid sox effect names """Gets list of valid sox effect names
Returns: list[str] Returns: list[str]
...@@ -49,7 +50,7 @@ class SoxEffectsChain(object): ...@@ -49,7 +50,7 @@ class SoxEffectsChain(object):
automatically. . (Default: ``'raw'``) automatically. . (Default: ``'raw'``)
Returns: Returns:
Tuple[torch.Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number Tuple[Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file) audio (as listed in the metadata of the file)
...@@ -77,9 +78,14 @@ class SoxEffectsChain(object): ...@@ -77,9 +78,14 @@ class SoxEffectsChain(object):
""" """
EFFECTS_UNIMPLEMENTED = set(["spectrogram", "splice", "noiseprof", "fir"]) EFFECTS_UNIMPLEMENTED = {"spectrogram", "splice", "noiseprof", "fir"}
def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"): def __init__(self,
normalization: Union[bool, float, Callable] = True,
channels_first: bool = True,
out_siginfo: Any = None,
out_encinfo: Any = None,
filetype: str = "raw") -> None:
self.input_file = None self.input_file = None
self.chain = [] self.chain = []
self.MAX_EFFECT_OPTS = 20 self.MAX_EFFECT_OPTS = 20
...@@ -92,12 +98,14 @@ class SoxEffectsChain(object): ...@@ -92,12 +98,14 @@ class SoxEffectsChain(object):
# Define in __init__ to avoid calling at import time # Define in __init__ to avoid calling at import time
self.EFFECTS_AVAILABLE = set(effect_names()) self.EFFECTS_AVAILABLE = set(effect_names())
def append_effect_to_chain(self, ename, eargs=None): def append_effect_to_chain(self,
ename: str,
eargs: Optional[List[str]] = None) -> None:
r"""Append effect to a sox effects chain. r"""Append effect to a sox effects chain.
Args: Args:
ename (str): which is the name of effect ename (str): which is the name of effect
eargs (List[str]): which is a list of effect options. (Default: ``None``) eargs (List[str], optional): which is a list of effect options. (Default: ``None``)
""" """
e = SoxEffect() e = SoxEffect()
# check if we have a valid effect # check if we have a valid effect
...@@ -116,14 +124,15 @@ class SoxEffectsChain(object): ...@@ -116,14 +124,15 @@ class SoxEffectsChain(object):
self.chain.append(e) self.chain.append(e)
@_audio_backend_guard("sox") @_audio_backend_guard("sox")
def sox_build_flow_effects(self, out=None): def sox_build_flow_effects(self,
out: Optional[Tensor] = None) -> Tuple[Tensor, int]:
r"""Build effects chain and flow effects from input file to output tensor r"""Build effects chain and flow effects from input file to output tensor
Args: Args:
out (torch.Tensor): Where the output will be written to. (Default: ``None``) out (Tensor, optional): Where the output will be written to. (Default: ``None``)
Returns: Returns:
Tuple[torch.Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number Tuple[Tensor, int]: An output Tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file) audio (as listed in the metadata of the file)
""" """
...@@ -154,12 +163,12 @@ class SoxEffectsChain(object): ...@@ -154,12 +163,12 @@ class SoxEffectsChain(object):
return out, sr return out, sr
def clear_chain(self): def clear_chain(self) -> None:
r"""Clear effects chain in python r"""Clear effects chain in python
""" """
self.chain = [] self.chain = []
def set_input_file(self, input_file): def set_input_file(self, input_file: str) -> None:
r"""Set input file for input of chain r"""Set input file for input of chain
Args: Args:
...@@ -167,7 +176,7 @@ class SoxEffectsChain(object): ...@@ -167,7 +176,7 @@ class SoxEffectsChain(object):
""" """
self.input_file = input_file self.input_file = input_file
def _check_effect(self, e): def _check_effect(self, e: str) -> str:
if e.lower() in self.EFFECTS_UNIMPLEMENTED: if e.lower() in self.EFFECTS_UNIMPLEMENTED:
raise NotImplementedError("This effect ({}) is not implement in torchaudio".format(e)) raise NotImplementedError("This effect ({}) is not implement in torchaudio".format(e))
elif e.lower() not in self.EFFECTS_AVAILABLE: elif e.lower() not in self.EFFECTS_AVAILABLE:
...@@ -176,7 +185,7 @@ class SoxEffectsChain(object): ...@@ -176,7 +185,7 @@ class SoxEffectsChain(object):
# https://stackoverflow.com/questions/12472338/flattening-a-list-recursively # https://stackoverflow.com/questions/12472338/flattening-a-list-recursively
# convenience function to flatten list recursively # convenience function to flatten list recursively
def _flatten(self, x): def _flatten(self, x: list) -> list:
if x == []: if x == []:
return [] return []
if isinstance(x[0], list): if isinstance(x[0], list):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment