Unverified Commit 2bce4898 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Inline typing transforms (#487)

* add inline typing

* update type hinting

* update typing

* sync the docstrings typing

* reorder imports, add typing to missing method

* add missing parenthesis
parent d069fb9f
# -*- coding: utf-8 -*-
import math
from typing import Callable, Optional
from warnings import warn

import torch
from torch import Tensor

from torchaudio import functional as F
from torchaudio.compliance import kaldi
...@@ -35,19 +37,25 @@ class Spectrogram(torch.nn.Module): ...@@ -35,19 +37,25 @@ class Spectrogram(torch.nn.Module):
win_length (int or None, optional): Window size. (Default: ``n_fft``) win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``) pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float or None, optional): Exponent for the magnitude spectrogram, power (float or None, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. (Default: ``2``) If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
""" """
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self, n_fft=400, win_length=None, hop_length=None, def __init__(self,
pad=0, window_fn=torch.hann_window, n_fft: int = 400,
power=2., normalized=False, wkwargs=None): win_length: Optional[int] = None,
hop_length: Optional[int] = None,
pad: int = 0,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: Optional[float] = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None) -> None:
super(Spectrogram, self).__init__() super(Spectrogram, self).__init__()
self.n_fft = n_fft self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
...@@ -60,13 +68,13 @@ class Spectrogram(torch.nn.Module): ...@@ -60,13 +68,13 @@ class Spectrogram(torch.nn.Module):
self.power = power self.power = power
self.normalized = normalized self.normalized = normalized
def forward(self, waveform): def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
Returns: Returns:
torch.Tensor: Dimension (..., freq, time), where freq is Tensor: Dimension (..., freq, time), where freq is
``n_fft // 2 + 1`` where ``n_fft`` is the number of ``n_fft // 2 + 1`` where ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
...@@ -96,12 +104,12 @@ class GriffinLim(torch.nn.Module): ...@@ -96,12 +104,12 @@ class GriffinLim(torch.nn.Module):
n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``) n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
win_length (int or None, optional): Window size. (Default: ``n_fft``) win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float, optional): Exponent for the magnitude spectrogram, power (float, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) wkwargs (dict or None, optional): Arguments for window function. (Default: ``None``)
momentum (float, optional): The momentum parameter for fast Griffin-Lim. momentum (float, optional): The momentum parameter for fast Griffin-Lim.
Setting this to 0 recovers the original Griffin-Lim method. Setting this to 0 recovers the original Griffin-Lim method.
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``) Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
...@@ -111,9 +119,18 @@ class GriffinLim(torch.nn.Module): ...@@ -111,9 +119,18 @@ class GriffinLim(torch.nn.Module):
__constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized', __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
'length', 'momentum', 'rand_init'] 'length', 'momentum', 'rand_init']
def __init__(self, n_fft=400, n_iter=32, win_length=None, hop_length=None, def __init__(self,
window_fn=torch.hann_window, power=2., normalized=False, wkwargs=None, n_fft: int = 400,
momentum=0.99, length=None, rand_init=True): n_iter: int = 32,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
window_fn: Callable[..., Tensor] = torch.hann_window,
power: float = 2.,
normalized: bool = False,
wkwargs: Optional[dict] = None,
momentum: float = 0.99,
length: Optional[int] = None,
rand_init: bool = True) -> None:
super(GriffinLim, self).__init__() super(GriffinLim, self).__init__()
assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum
...@@ -131,8 +148,16 @@ class GriffinLim(torch.nn.Module): ...@@ -131,8 +148,16 @@ class GriffinLim(torch.nn.Module):
self.momentum = momentum / (1 + momentum) self.momentum = momentum / (1 + momentum)
self.rand_init = rand_init self.rand_init = rand_init
def forward(self, specgram: Tensor) -> Tensor:
    r"""Run the Griffin-Lim phase-recovery algorithm on a magnitude spectrogram.

    Args:
        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
            where freq is ``n_fft // 2 + 1``.

    Returns:
        Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
    """
    # All iteration/window parameters were fixed at construction time; the
    # functional implementation carries out the actual optimization loop.
    return F.griffinlim(specgram, self.window, self.n_fft, self.hop_length, self.win_length, self.power,
                        self.normalized, self.n_iter, self.momentum, self.length, self.rand_init)
...@@ -151,7 +176,7 @@ class AmplitudeToDB(torch.nn.Module): ...@@ -151,7 +176,7 @@ class AmplitudeToDB(torch.nn.Module):
""" """
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
def __init__(self, stype='power', top_db=None): def __init__(self, stype: str = 'power', top_db: Optional[float] = None) -> None:
super(AmplitudeToDB, self).__init__() super(AmplitudeToDB, self).__init__()
self.stype = stype self.stype = stype
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
...@@ -162,15 +187,15 @@ class AmplitudeToDB(torch.nn.Module): ...@@ -162,15 +187,15 @@ class AmplitudeToDB(torch.nn.Module):
self.ref_value = 1.0 self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) self.db_multiplier = math.log10(max(self.amin, self.ref_value))
def forward(self, x: Tensor) -> Tensor:
    r"""Numerically stable implementation from Librosa.
    https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

    Args:
        x (Tensor): Input tensor before being converted to decibel scale.

    Returns:
        Tensor: Output tensor in decibel scale.
    """
    # multiplier/amin/db_multiplier/top_db were precomputed in __init__.
    return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
...@@ -191,7 +216,12 @@ class MelScale(torch.nn.Module): ...@@ -191,7 +216,12 @@ class MelScale(torch.nn.Module):
""" """
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max'] __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None): def __init__(self,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
n_stft: Optional[int] = None) -> None:
super(MelScale, self).__init__() super(MelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -204,13 +234,13 @@ class MelScale(torch.nn.Module): ...@@ -204,13 +234,13 @@ class MelScale(torch.nn.Module):
n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, specgram): def forward(self, specgram: Tensor) -> Tensor:
r""" r"""
Args: Args:
specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time). specgram (Tensor): A spectrogram STFT of dimension (..., freq, time).
Returns: Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
""" """
# pack batch # pack batch
...@@ -242,20 +272,28 @@ class InverseMelScale(torch.nn.Module): ...@@ -242,20 +272,28 @@ class InverseMelScale(torch.nn.Module):
Args: Args:
n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`. n_stft (int): Number of bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
n_mels (int): Number of mel filterbanks. (Default: ``128``) n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
sample_rate (int): Sample rate of audio signal. (Default: ``16000``) sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float): Minimum frequency. (Default: ``0.``) f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``) f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
max_iter (int): Maximum number of optimization iterations. max_iter (int, optional): Maximum number of optimization iterations. (Default: ``100000``)
tolerance_loss (float): Value of loss to stop optimization at. tolerance_loss (float, optional): Value of loss to stop optimization at. (Default: ``1e-5``)
tolerance_change (float): Difference in losses to stop optimization at. tolerance_change (float, optional): Difference in losses to stop optimization at. (Default: ``1e-8``)
sgdargs (dict): Arguments for the SGD optimizer. sgdargs (dict or None, optional): Arguments for the SGD optimizer. (Default: ``None``)
""" """
__constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss', __constants__ = ['n_stft', 'n_mels', 'sample_rate', 'f_min', 'f_max', 'max_iter', 'tolerance_loss',
'tolerance_change', 'sgdargs'] 'tolerance_change', 'sgdargs']
def __init__(self, n_stft, n_mels=128, sample_rate=16000, f_min=0., f_max=None, max_iter=100000, def __init__(self,
tolerance_loss=1e-5, tolerance_change=1e-8, sgdargs=None): n_stft: int,
n_mels: int = 128,
sample_rate: int = 16000,
f_min: float = 0.,
f_max: Optional[float] = None,
max_iter: int = 100000,
tolerance_loss: float = 1e-5,
tolerance_change: float = 1e-8,
sgdargs: Optional[dict] = None) -> None:
super(InverseMelScale, self).__init__() super(InverseMelScale, self).__init__()
self.n_mels = n_mels self.n_mels = n_mels
self.sample_rate = sample_rate self.sample_rate = sample_rate
...@@ -271,13 +309,13 @@ class InverseMelScale(torch.nn.Module): ...@@ -271,13 +309,13 @@ class InverseMelScale(torch.nn.Module):
fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate) fb = F.create_fb_matrix(n_stft, self.f_min, self.f_max, self.n_mels, self.sample_rate)
self.register_buffer('fb', fb) self.register_buffer('fb', fb)
def forward(self, melspec): def forward(self, melspec: Tensor) -> Tensor:
r""" r"""
Args: Args:
melspec (torch.Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time) melspec (Tensor): A Mel frequency spectrogram of dimension (..., ``n_mels``, time)
Returns: Returns:
torch.Tensor: Linear scale spectrogram of size (..., freq, time) Tensor: Linear scale spectrogram of size (..., freq, time)
""" """
# pack batch # pack batch
shape = melspec.size() shape = melspec.size()
...@@ -335,7 +373,7 @@ class MelSpectrogram(torch.nn.Module): ...@@ -335,7 +373,7 @@ class MelSpectrogram(torch.nn.Module):
f_max (float or None, optional): Maximum frequency. (Default: ``None``) f_max (float or None, optional): Maximum frequency. (Default: ``None``)
pad (int, optional): Two sided padding of signal. (Default: ``0``) pad (int, optional): Two sided padding of signal. (Default: ``0``)
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``) n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor window_fn (Callable[..., Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
...@@ -345,8 +383,17 @@ class MelSpectrogram(torch.nn.Module): ...@@ -345,8 +383,17 @@ class MelSpectrogram(torch.nn.Module):
""" """
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min'] __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']
def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None, def __init__(self,
pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None): sample_rate: int = 16000,
n_fft: int = 400,
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
f_min: float = 0.,
f_max: Optional[float] = None,
pad: int = 0,
n_mels: int = 128,
window_fn: Callable[..., Tensor] = torch.hann_window,
wkwargs: Optional[dict] = None) -> None:
super(MelSpectrogram, self).__init__() super(MelSpectrogram, self).__init__()
self.sample_rate = sample_rate self.sample_rate = sample_rate
self.n_fft = n_fft self.n_fft = n_fft
...@@ -362,13 +409,13 @@ class MelSpectrogram(torch.nn.Module): ...@@ -362,13 +409,13 @@ class MelSpectrogram(torch.nn.Module):
normalized=False, wkwargs=wkwargs) normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1) self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
def forward(self, waveform): def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
Returns: Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time). Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
""" """
specgram = self.spectrogram(waveform) specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram) mel_specgram = self.mel_scale(specgram)
...@@ -392,12 +439,17 @@ class MFCC(torch.nn.Module): ...@@ -392,12 +439,17 @@ class MFCC(torch.nn.Module):
dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``) dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
norm (str, optional): norm to use. (Default: ``'ortho'``) norm (str, optional): norm to use. (Default: ``'ortho'``)
log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``) log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``) melkwargs (dict or None, optional): arguments for MelSpectrogram. (Default: ``None``)
""" """
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels'] __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False, def __init__(self,
melkwargs=None): sample_rate: int = 16000,
n_mfcc: int = 40,
dct_type: int = 2,
norm: str = 'ortho',
log_mels: bool = False,
melkwargs: Optional[dict] = None) -> None:
super(MFCC, self).__init__() super(MFCC, self).__init__()
supported_dct_types = [2] supported_dct_types = [2]
if dct_type not in supported_dct_types: if dct_type not in supported_dct_types:
...@@ -420,13 +472,13 @@ class MFCC(torch.nn.Module): ...@@ -420,13 +472,13 @@ class MFCC(torch.nn.Module):
self.register_buffer('dct_mat', dct_mat) self.register_buffer('dct_mat', dct_mat)
self.log_mels = log_mels self.log_mels = log_mels
def forward(self, waveform): def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
Returns: Returns:
torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time). Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
""" """
# pack batch # pack batch
...@@ -461,17 +513,17 @@ class MuLawEncoding(torch.nn.Module): ...@@ -461,17 +513,17 @@ class MuLawEncoding(torch.nn.Module):
""" """
__constants__ = ['quantization_channels'] __constants__ = ['quantization_channels']
def __init__(self, quantization_channels: int = 256) -> None:
    # quantization_channels: number of mu-law quantization levels (default 256).
    super(MuLawEncoding, self).__init__()
    self.quantization_channels = quantization_channels
def forward(self, x: Tensor) -> Tensor:
    r"""Encode a signal using mu-law companding.

    Args:
        x (Tensor): A signal to be encoded.

    Returns:
        Tensor: An encoded signal (``x_mu``).
    """
    return F.mu_law_encoding(x, self.quantization_channels)
...@@ -488,17 +540,17 @@ class MuLawDecoding(torch.nn.Module): ...@@ -488,17 +540,17 @@ class MuLawDecoding(torch.nn.Module):
""" """
__constants__ = ['quantization_channels'] __constants__ = ['quantization_channels']
def __init__(self, quantization_channels: int = 256) -> None:
    # quantization_channels: number of mu-law quantization levels (default 256);
    # must match the value used for encoding.
    super(MuLawDecoding, self).__init__()
    self.quantization_channels = quantization_channels
def forward(self, x_mu: Tensor) -> Tensor:
    r"""Decode a mu-law encoded signal.

    Args:
        x_mu (Tensor): A mu-law encoded signal which needs to be decoded.

    Returns:
        Tensor: The signal decoded.
    """
    return F.mu_law_decoding(x_mu, self.quantization_channels)
...@@ -512,19 +564,22 @@ class Resample(torch.nn.Module): ...@@ -512,19 +564,22 @@ class Resample(torch.nn.Module):
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``) resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
""" """
def __init__(self,
             orig_freq: int = 16000,
             new_freq: int = 16000,
             resampling_method: str = 'sinc_interpolation') -> None:
    # orig_freq/new_freq: source and target sample rates in Hz.
    # resampling_method: only 'sinc_interpolation' is dispatched in forward().
    super(Resample, self).__init__()
    self.orig_freq = orig_freq
    self.new_freq = new_freq
    self.resampling_method = resampling_method
def forward(self, waveform): def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
Returns: Returns:
torch.Tensor: Output signal of dimension (..., time). Tensor: Output signal of dimension (..., time).
""" """
if self.resampling_method == 'sinc_interpolation': if self.resampling_method == 'sinc_interpolation':
...@@ -550,11 +605,11 @@ class ComplexNorm(torch.nn.Module): ...@@ -550,11 +605,11 @@ class ComplexNorm(torch.nn.Module):
""" """
__constants__ = ['power'] __constants__ = ['power']
def __init__(self, power: float = 1.0) -> None:
    # power: exponent applied to the complex norm (1.0 = magnitude).
    super(ComplexNorm, self).__init__()
    self.power = power
def forward(self, complex_tensor): def forward(self, complex_tensor: Tensor) -> Tensor:
r""" r"""
Args: Args:
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`. complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.
...@@ -576,18 +631,18 @@ class ComputeDeltas(torch.nn.Module): ...@@ -576,18 +631,18 @@ class ComputeDeltas(torch.nn.Module):
""" """
__constants__ = ['win_length'] __constants__ = ['win_length']
def __init__(self, win_length: int = 5, mode: str = "replicate") -> None:
    # win_length: window size used to estimate deltas.
    # mode: padding mode passed through to F.compute_deltas.
    super(ComputeDeltas, self).__init__()
    self.win_length = win_length
    self.mode = mode
def forward(self, specgram: Tensor) -> Tensor:
    r"""Compute delta coefficients of a spectrogram.

    Args:
        specgram (Tensor): Tensor of audio of dimension (..., freq, time).

    Returns:
        Tensor: Tensor of deltas of dimension (..., freq, time).
    """
    return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
...@@ -603,7 +658,10 @@ class TimeStretch(torch.nn.Module): ...@@ -603,7 +658,10 @@ class TimeStretch(torch.nn.Module):
""" """
__constants__ = ['fixed_rate'] __constants__ = ['fixed_rate']
def __init__(self, hop_length=None, n_freq=201, fixed_rate=None): def __init__(self,
hop_length: Optional[int] = None,
n_freq: int = 201,
fixed_rate: Optional[float] = None) -> None:
super(TimeStretch, self).__init__() super(TimeStretch, self).__init__()
self.fixed_rate = fixed_rate self.fixed_rate = fixed_rate
...@@ -612,8 +670,7 @@ class TimeStretch(torch.nn.Module): ...@@ -612,8 +670,7 @@ class TimeStretch(torch.nn.Module):
hop_length = hop_length if hop_length is not None else n_fft // 2 hop_length = hop_length if hop_length is not None else n_fft // 2
self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None]) self.register_buffer('phase_advance', torch.linspace(0, math.pi * hop_length, n_freq)[..., None])
def forward(self, complex_specgrams, overriding_rate=None): def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = None) -> Tensor:
# type: (Tensor, Optional[float]) -> Tensor
r""" r"""
Args: Args:
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2). complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
...@@ -621,7 +678,7 @@ class TimeStretch(torch.nn.Module): ...@@ -621,7 +678,7 @@ class TimeStretch(torch.nn.Module):
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``) If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
Returns: Returns:
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2). Tensor: Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
""" """
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)" assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
...@@ -648,27 +705,28 @@ class Fade(torch.nn.Module): ...@@ -648,27 +705,28 @@ class Fade(torch.nn.Module):
fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine", fade_shape (str, optional): Shape of fade. Must be one of: "quarter_sine",
"half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``) "half_sine", "linear", "logarithmic", "exponential". (Default: ``"linear"``)
""" """
def __init__(self,
             fade_in_len: int = 0,
             fade_out_len: int = 0,
             fade_shape: str = "linear") -> None:
    # fade_in_len/fade_out_len: lengths (in samples) of the fade ramps; 0 disables.
    # fade_shape: one of "quarter_sine", "half_sine", "linear", "logarithmic", "exponential".
    super(Fade, self).__init__()
    self.fade_in_len = fade_in_len
    self.fade_out_len = fade_out_len
    self.fade_shape = fade_shape
def forward(self, waveform: Tensor) -> Tensor:
    r"""Apply fade-in and fade-out envelopes to a waveform.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time).

    Returns:
        Tensor: Tensor of audio of dimension (..., time).
    """
    waveform_length = waveform.size()[-1]
    # Elementwise product of the two ramp envelopes with the signal; the ramps
    # are built per-call so they always match the current waveform length.
    return self._fade_in(waveform_length) * self._fade_out(waveform_length) * waveform
def _fade_in(self, waveform_length): def _fade_in(self, waveform_length: int) -> Tensor:
# type: (int) -> Tensor
fade = torch.linspace(0, 1, self.fade_in_len) fade = torch.linspace(0, 1, self.fade_in_len)
ones = torch.ones(waveform_length - self.fade_in_len) ones = torch.ones(waveform_length - self.fade_in_len)
...@@ -689,8 +747,7 @@ class Fade(torch.nn.Module): ...@@ -689,8 +747,7 @@ class Fade(torch.nn.Module):
return torch.cat((fade, ones)).clamp_(0, 1) return torch.cat((fade, ones)).clamp_(0, 1)
def _fade_out(self, waveform_length): def _fade_out(self, waveform_length: int) -> Tensor:
# type: (int) -> Tensor
fade = torch.linspace(0, 1, self.fade_out_len) fade = torch.linspace(0, 1, self.fade_out_len)
ones = torch.ones(waveform_length - self.fade_out_len) ones = torch.ones(waveform_length - self.fade_out_len)
...@@ -722,22 +779,21 @@ class _AxisMasking(torch.nn.Module): ...@@ -722,22 +779,21 @@ class _AxisMasking(torch.nn.Module):
""" """
__constants__ = ['mask_param', 'axis', 'iid_masks'] __constants__ = ['mask_param', 'axis', 'iid_masks']
def __init__(self, mask_param: int, axis: int, iid_masks: bool) -> None:
    # mask_param: maximum length of the mask.
    # axis: axis to mask (1 = frequency, 2 = time in the public subclasses).
    # iid_masks: whether to draw independent masks per batch example.
    super(_AxisMasking, self).__init__()
    self.mask_param = mask_param
    self.axis = axis
    self.iid_masks = iid_masks
def forward(self, specgram, mask_value=0.): def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
# type: (Tensor, float) -> Tensor
r""" r"""
Args: Args:
specgram (torch.Tensor): Tensor of dimension (..., freq, time). specgram (Tensor): Tensor of dimension (..., freq, time).
mask_value (float): Value to assign to the masked columns. mask_value (float): Value to assign to the masked columns.
Returns: Returns:
torch.Tensor: Masked spectrogram of dimensions (..., freq, time). Tensor: Masked spectrogram of dimensions (..., freq, time).
""" """
# if iid_masks flag marked and specgram has a batch dimension # if iid_masks flag marked and specgram has a batch dimension
...@@ -757,7 +813,7 @@ class FrequencyMasking(_AxisMasking): ...@@ -757,7 +813,7 @@ class FrequencyMasking(_AxisMasking):
the examples/channels in the batch. (Default: ``False``) the examples/channels in the batch. (Default: ``False``)
""" """
def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
    # Axis 1 is the frequency axis of a (..., freq, time) spectrogram.
    super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
...@@ -771,7 +827,7 @@ class TimeMasking(_AxisMasking): ...@@ -771,7 +827,7 @@ class TimeMasking(_AxisMasking):
the examples/channels in the batch. (Default: ``False``) the examples/channels in the batch. (Default: ``False``)
""" """
def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
    # Axis 2 is the time axis of a (..., freq, time) spectrogram.
    super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
...@@ -786,7 +842,7 @@ class Vol(torch.nn.Module): ...@@ -786,7 +842,7 @@ class Vol(torch.nn.Module):
gain_type (str, optional): Type of gain. One of: ‘amplitude’, ‘power’, ‘db’ (Default: ``"amplitude"``) gain_type (str, optional): Type of gain. One of: ‘amplitude’, ‘power’, ‘db’ (Default: ``"amplitude"``)
""" """
def __init__(self, gain, gain_type='amplitude'): def __init__(self, gain: float, gain_type: str = 'amplitude'):
super(Vol, self).__init__() super(Vol, self).__init__()
self.gain = gain self.gain = gain
self.gain_type = gain_type self.gain_type = gain_type
...@@ -794,14 +850,13 @@ class Vol(torch.nn.Module): ...@@ -794,14 +850,13 @@ class Vol(torch.nn.Module):
if gain_type in ['amplitude', 'power'] and gain < 0: if gain_type in ['amplitude', 'power'] and gain < 0:
raise ValueError("If gain_type = amplitude or power, gain must be positive.") raise ValueError("If gain_type = amplitude or power, gain must be positive.")
def forward(self, waveform): def forward(self, waveform: Tensor) -> Tensor:
# type: (Tensor) -> Tensor
r""" r"""
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
Returns: Returns:
torch.Tensor: Tensor of audio of dimension (..., time). Tensor: Tensor of audio of dimension (..., time).
""" """
if self.gain_type == "amplitude": if self.gain_type == "amplitude":
waveform = waveform * self.gain waveform = waveform * self.gain
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment