Unverified Commit 4c221140 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Add inline typing to functional (#482)

* add typing to functional

* fix minor things

* fix flake8
parent a72dd836
# -*- coding: utf-8 -*-
import math
from typing import Optional, Tuple
import torch
from torch import Tensor
__all__ = [
"istft",
......@@ -37,17 +39,16 @@ __all__ = [
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(
waveform,
n_fft,
hop_length,
win_length,
window,
center,
pad_mode,
normalized,
onesided,
):
# type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
waveform: Tensor,
n_fft: int,
hop_length: Optional[int],
win_length: Optional[int],
window: Optional[Tensor],
center: bool,
pad_mode: str,
normalized: bool,
onesided: bool
) -> Tensor:
return torch.stft(
waveform,
n_fft,
......@@ -62,18 +63,17 @@ def _stft(
def istft(
stft_matrix, # type: Tensor
n_fft, # type: int
hop_length=None, # type: Optional[int]
win_length=None, # type: Optional[int]
window=None, # type: Optional[Tensor]
center=True, # type: bool
pad_mode="reflect", # type: str
normalized=False, # type: bool
onesided=True, # type: bool
length=None, # type: Optional[int]
):
# type: (...) -> Tensor
stft_matrix: Tensor,
n_fft: int,
hop_length: Optional[int] = None,
win_length: Optional[int] = None,
window: Optional[Tensor] = None,
center: bool = True,
pad_mode: str = "reflect",
normalized: bool = False,
onesided: bool = True,
length: Optional[int] = None,
) -> Tensor:
r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of ``length``) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition (
......@@ -103,26 +103,26 @@ def istft(
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
stft_matrix (Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
hop_length (int or None, optional): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``)
win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
window (Optional[torch.Tensor]): The optional window function.
win_length (int or None, optional): The size of window frame and STFT filter. (Default: ``n_fft``)
window (Tensor or None, optional): The optional window function.
(Default: ``torch.ones(win_length)``)
center (bool): Whether ``input`` was padded on both sides so
center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``)
pad_mode (str): Controls the padding method used when ``center`` is True. (Default:
``'reflect'``)
normalized (bool): Whether the STFT was normalized. (Default: ``False``)
onesided (bool): Whether the STFT is onesided. (Default: ``True``)
length (Optional[int]): The amount to trim the signal by (i.e. the
pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
``"reflect"``)
normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (int or None, optional): The amount to trim the signal by (i.e. the
original signal length). (Default: whole signal)
Returns:
torch.Tensor: Least squares estimation of the original signal of size (..., signal_length)
Tensor: Least squares estimation of the original signal of size (..., signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
......@@ -226,26 +226,32 @@ def istft(
def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized
):
# type: (Tensor, int, Tensor, int, int, int, Optional[float], bool) -> Tensor
waveform: Tensor,
pad: int,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: Optional[float],
normalized: bool
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
waveform (Tensor): Tensor of audio of dimension (..., time)
pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows
win_length (int): Window size
power (float): Exponent for the magnitude spectrogram,
power (float or None): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft
Returns:
torch.Tensor: Dimension (..., freq, time), freq is
Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
......@@ -275,9 +281,18 @@ def spectrogram(
def griffinlim(
specgram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init
):
# type: (Tensor, Tensor, int, int, int, float, bool, int, float, Optional[int], bool) -> Tensor
specgram: Tensor,
window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: float,
normalized: bool,
n_iter: int,
momentum: float,
length: Optional[int],
rand_init: bool
) -> Tensor:
r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
Implementation ported from `librosa`.
......@@ -295,22 +310,22 @@ def griffinlim(
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args:
specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``.
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
hop_length (int): Length of hop between STFT windows. (
Default: ``win_length // 2``)
win_length (int): Window size. (Default: ``n_fft``)
power (float): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
(must be > 0) e.g., 1 for energy, 2 for power, etc.
normalized (bool): Whether to normalize by magnitude after stft.
n_iter (int): Number of iteration for phase recovery process.
momentum (float): The momentum parameter for fast Griffin-Lim.
Setting this to 0 recovers the original Griffin-Lim method.
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
length (Optional[int]): Array length of the expected output. (Default: ``None``)
rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``)
Values near 1 can lead to faster convergence, but above 1 may not converge.
length (int or None): Array length of the expected output.
rand_init (bool): Initializes phase randomly if True, to zero otherwise.
Returns:
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
......@@ -331,7 +346,7 @@ def griffinlim(
else:
angles = torch.zeros(batch, freq, frames)
angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
.to(dtype=specgram.dtype, device=specgram.device)
.to(dtype=specgram.dtype, device=specgram.device)
specgram = specgram.unsqueeze(-1).expand_as(angles)
# And initialize the previous iterate to 0
......@@ -371,8 +386,13 @@ def griffinlim(
return waveform
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
def amplitude_to_DB(
x: Tensor,
multiplier: float,
amin: float,
db_multiplier: float,
top_db: Optional[float] = None
) -> Tensor:
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input tensor, and so
......@@ -380,15 +400,15 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
a full clip.
Args:
x (torch.Tensor): Input tensor before being converted to decibel scale
x (Tensor): Input tensor before being converted to decibel scale
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp ``x``
db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
is 80. (Default: ``None``)
Returns:
torch.Tensor: Output tensor in decibel scale
Tensor: Output tensor in decibel scale
"""
x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
x_db -= multiplier * db_multiplier
......@@ -399,23 +419,31 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
return x_db
def DB_to_amplitude(x, ref, power):
# type: (Tensor, float, float) -> Tensor
def DB_to_amplitude(
x: Tensor,
ref: float,
power: float
) -> Tensor:
r"""Turn a tensor from the decibel scale to the power/amplitude scale.
Args:
x (torch.Tensor): Input tensor before being converted to power/amplitude scale.
x (Tensor): Input tensor before being converted to power/amplitude scale.
ref (float): Reference which the output will be scaled by.
power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
Returns:
torch.Tensor: Output tensor in power/amplitude scale.
Tensor: Output tensor in power/amplitude scale.
"""
return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
# type: (int, float, float, int, int) -> Tensor
def create_fb_matrix(
n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int
) -> Tensor:
r"""Create a frequency bin conversion matrix.
Args:
......@@ -426,7 +454,7 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
sample_rate (int): Sample rate of the audio waveform
Returns:
torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be
......@@ -456,18 +484,21 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
return fb
def create_dct(n_mfcc, n_mels, norm):
# type: (int, int, Optional[str]) -> Tensor
def create_dct(
n_mfcc: int,
n_mels: int,
norm: Optional[str]
) -> Tensor:
r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
normalized depending on norm.
Args:
n_mfcc (int): Number of mfc coefficients to retain
n_mels (int): Number of mel filterbanks
norm (Optional[str]): Norm to use (either 'ortho' or None)
norm (str or None): Norm to use (either 'ortho' or None)
Returns:
torch.Tensor: The transformation matrix, to be right-multiplied to
Tensor: The transformation matrix, to be right-multiplied to
row-wise data of size (``n_mels``, ``n_mfcc``).
"""
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
......@@ -483,8 +514,10 @@ def create_dct(n_mfcc, n_mels, norm):
return dct.t()
def mu_law_encoding(x, quantization_channels):
# type: (Tensor, int) -> Tensor
def mu_law_encoding(
x: Tensor,
quantization_channels: int
) -> Tensor:
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -492,11 +525,11 @@ def mu_law_encoding(x, quantization_channels):
returns a signal encoded with values from 0 to quantization_channels - 1.
Args:
x (torch.Tensor): Input tensor
x (Tensor): Input tensor
quantization_channels (int): Number of channels
Returns:
torch.Tensor: Input after mu-law encoding
Tensor: Input after mu-law encoding
"""
mu = quantization_channels - 1.0
if not x.is_floating_point():
......@@ -507,8 +540,10 @@ def mu_law_encoding(x, quantization_channels):
return x_mu
def mu_law_decoding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor
def mu_law_decoding(
x_mu: Tensor,
quantization_channels: int
) -> Tensor:
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -516,11 +551,11 @@ def mu_law_decoding(x_mu, quantization_channels):
and returns a signal scaled between -1 and 1.
Args:
x_mu (torch.Tensor): Input tensor
x_mu (Tensor): Input tensor
quantization_channels (int): Number of channels
Returns:
torch.Tensor: Input after mu-law decoding
Tensor: Input after mu-law decoding
"""
mu = quantization_channels - 1.0
if not x_mu.is_floating_point():
......@@ -531,65 +566,71 @@ def mu_law_decoding(x_mu, quantization_channels):
return x
def complex_norm(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tensor
def complex_norm(
complex_tensor: Tensor,
power: float = 1.0
) -> Tensor:
r"""Compute the norm of complex tensor input.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`).
Returns:
torch.Tensor: Power of the normed input tensor. Shape of `(..., )`
Tensor: Power of the normed input tensor. Shape of `(..., )`
"""
if power == 1.0:
return torch.norm(complex_tensor, 2, -1)
return torch.norm(complex_tensor, 2, -1).pow(power)
def angle(complex_tensor):
# type: (Tensor) -> Tensor
def angle(
complex_tensor: Tensor
) -> Tensor:
r"""Compute the angle of complex tensor input.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
Return:
torch.Tensor: Angle of a complex tensor. Shape of `(..., )`
Tensor: Angle of a complex tensor. Shape of `(..., )`
"""
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
def magphase(complex_tensor, power=1.0):
# type: (Tensor, float) -> Tuple[Tensor, Tensor]
def magphase(
complex_tensor: Tensor,
power: float = 1.0
) -> Tuple[Tensor, Tensor]:
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)`
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`)
Returns:
Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex tensor
(Tensor, Tensor): The magnitude and phase of the complex tensor
"""
mag = complex_norm(complex_tensor, power)
phase = angle(complex_tensor)
return mag, phase
def phase_vocoder(complex_specgrams, rate, phase_advance):
# type: (Tensor, float, Tensor) -> Tensor
def phase_vocoder(
complex_specgrams: Tensor,
rate: float,
phase_advance: Tensor
) -> Tensor:
r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``.
Args:
complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)
phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
Returns:
complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
freq, ceil(time/rate), complex=2)`
Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
Example
>>> freq, hop_length = 1025, 512
......@@ -650,22 +691,24 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
return complex_specgrams_stretch
def lfilter(waveform, a_coeffs, b_coeffs):
# type: (Tensor, Tensor, Tensor) -> Tensor
def lfilter(
waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (torch.Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[b0, b1, b2, ...]`.
Must be same size as a_coeffs (pad with 0's as necessary).
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1.
Tensor: Waveform with dimension of `(..., time)`. Output will be clipped to -1 to 1.
"""
dim = waveform.dim()
......@@ -674,17 +717,17 @@ def lfilter(waveform, a_coeffs, b_coeffs):
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device)
assert(b_coeffs.device == a_coeffs.device)
assert (a_coeffs.size(0) == b_coeffs.size(0))
assert (len(waveform.size()) == 2)
assert (waveform.device == a_coeffs.device)
assert (b_coeffs.device == a_coeffs.device)
device = waveform.device
dtype = waveform.dtype
n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(0)
n_sample_padded = n_sample + n_order - 1
assert(n_order > 0)
assert (n_order > 0)
# Pad the input and create output
padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
......@@ -720,13 +763,20 @@ def lfilter(waveform, a_coeffs, b_coeffs):
return output
def biquad(waveform, b0, b1, b2, a0, a1, a2):
# type: (Tensor, float, float, float, float, float, float) -> Tensor
def biquad(
waveform: Tensor,
b0: float,
b1: float,
b2: float,
a0: float,
a1: float,
a2: float
) -> Tensor:
r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
https://en.wikipedia.org/wiki/Digital_biquad_filter
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2]
......@@ -735,7 +785,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of current output y[n-2]
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform with dimension of `(..., time)`
"""
device = waveform.device
......@@ -749,23 +799,26 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
return output_waveform
def _dB2Linear(x):
# type: (float) -> float
def _dB2Linear(x: float) -> float:
return math.exp(x * math.log(10) / 20.0)
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
def highpass_biquad(
waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform dimension of `(..., time)`
"""
GAIN = 1.
......@@ -783,18 +836,22 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
def lowpass_biquad(
waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
"""
GAIN = 1.
......@@ -812,18 +869,22 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def allpass_biquad(waveform, sample_rate, central_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
def allpass_biquad(
waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design two-pole all-pass filter. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -841,20 +902,25 @@ def allpass_biquad(waveform, sample_rate, central_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def bandpass_biquad(waveform, sample_rate, central_freq, Q=0.707, const_skirt_gain=False):
# type: (Tensor, int, float, float, bool) -> Tensor
def bandpass_biquad(
waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707,
const_skirt_gain: bool = False
) -> Tensor:
r"""Design two-pole band-pass filter. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
const_skirt_gain (bool) : If ``True``, uses a constant skirt gain (peak gain = Q).
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q).
If ``False``, uses a constant 0dB peak gain. (Default: ``False``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -873,18 +939,22 @@ def bandpass_biquad(waveform, sample_rate, central_freq, Q=0.707, const_skirt_ga
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def bandreject_biquad(waveform, sample_rate, central_freq, Q=0.707):
# type: (Tensor, int, float, float) -> Tensor
def bandreject_biquad(
waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design two-pole band-reject filter. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -902,19 +972,24 @@ def bandreject_biquad(waveform, sample_rate, central_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
# type: (Tensor, int, float, float, float) -> Tensor
def equalizer_biquad(
waveform: Tensor,
sample_rate: int,
center_freq: float,
gain: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
"""
w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10))
......@@ -929,21 +1004,26 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def band_biquad(waveform, sample_rate, central_freq, Q=0.707, noise=False):
# type: (Tensor, int, float, float, bool) -> Tensor
def band_biquad(
waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707,
noise: bool = False
) -> Tensor:
r"""Design two-pole band filter. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
noise (bool) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion).
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
noise (bool, optional) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion).
If ``False``, uses mode oriented to pitched audio, i.e. voice, singing,
or instrumental music. (Default: ``False``)
or instrumental music (Default: ``False``).
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -969,19 +1049,24 @@ def band_biquad(waveform, sample_rate, central_freq, Q=0.707, noise=False):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def treble_biquad(waveform, sample_rate, gain, central_freq=3000, Q=0.707):
# type: (Tensor, int, float, float, float) -> Tensor
def treble_biquad(
waveform: Tensor,
sample_rate: int,
gain: float,
central_freq: float = 3000,
Q: float = 0.707
) -> Tensor:
r"""Design a treble tone-control effect. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
gain (float): desired gain at the boost (or attenuation) in dB.
central_freq (float): central frequency (in Hz). (Default: ``3000``)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor
central_freq (float, optional): central frequency (in Hz). (Default: ``3000``)
Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -1005,16 +1090,18 @@ def treble_biquad(waveform, sample_rate, gain, central_freq=3000, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def deemph_biquad(waveform, sample_rate):
# type: (Tensor, int) -> Tensor
def deemph_biquad(
waveform: Tensor,
sample_rate: int
) -> Tensor:
r"""Apply ISO 908 CD de-emphasis (shelving) IIR filter. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, Allowed sample rate ``44100`` or ``48000``
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -1050,17 +1137,19 @@ def deemph_biquad(waveform, sample_rate):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def riaa_biquad(waveform, sample_rate):
# type: (Tensor, int) -> Tensor
def riaa_biquad(
waveform: Tensor,
sample_rate: int
) -> Tensor:
r"""Apply RIAA vinyl playback equalisation. Similar to SoX implementation.
Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz).
Allowed sample rates in Hz : ``44100``,``48000``,``88200``,``96000``
Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`
Tensor: Waveform of dimension of `(..., time)`
References:
http://sox.sourceforge.net/sox.html
......@@ -1102,7 +1191,7 @@ def riaa_biquad(waveform, sample_rate):
a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y)
b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y)
a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y)
g = 1 / math.sqrt((b_re**2 + b_im**2) / (a_re**2 + a_im**2))
g = 1 / math.sqrt((b_re ** 2 + b_im ** 2) / (a_re ** 2 + a_im ** 2))
b0 *= g
b1 *= g
......@@ -1111,8 +1200,12 @@ def riaa_biquad(waveform, sample_rate):
return biquad(waveform, b0, b1, b2, a0, a1, a2)
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
def mask_along_axis_iid(
specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
......@@ -1125,7 +1218,7 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
Returns:
torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
"""
if axis != 2 and axis != 3:
......@@ -1147,8 +1240,12 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
return specgrams
def mask_along_axis(specgram, mask_param, mask_value, axis):
# type: (Tensor, int, float, int) -> Tensor
def mask_along_axis(
specgram: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
......@@ -1161,7 +1258,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
Returns:
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""
# pack batch
......@@ -1188,8 +1285,11 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
return specgram
def compute_deltas(specgram, win_length=5, mode="replicate"):
# type: (Tensor, int, str) -> Tensor
def compute_deltas(
specgram: Tensor,
win_length: int = 5,
mode: str = "replicate"
) -> Tensor:
r"""Compute delta coefficients of a tensor, usually a spectrogram:
.. math::
......@@ -1200,12 +1300,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2.
Args:
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int): The window length used for computing delta
mode (str): Mode parameter passed to padding
specgram (Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int, optional): The window length used for computing delta (Default: ``5``)
mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
Tensor: Tensor of deltas of dimension (..., freq, time)
Example
>>> specgram = torch.randn(1, 40, 1000)
......@@ -1226,11 +1326,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
kernel = (
torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], 1, 1)
)
kernel = (torch.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype).repeat(specgram.shape[1], 1, 1))
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
......@@ -1240,16 +1336,18 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output
def gain(waveform, gain_db=1.0):
# type: (Tensor, float) -> Tensor
def gain(
waveform: Tensor,
gain_db: float = 1.0
) -> Tensor:
r"""Apply amplification or attenuation to the whole waveform.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`).
waveform (Tensor): Tensor of audio of dimension (..., time).
gain_db (float, optional) Gain adjustment in decibels (dB) (Default: ``1.0``).
Returns:
torch.Tensor: the whole waveform amplified by gain_db.
Tensor: the whole waveform amplified by gain_db.
"""
if (gain_db == 0):
return waveform
......@@ -1259,7 +1357,10 @@ def gain(waveform, gain_db=1.0):
return waveform * ratio
def _add_noise_shaping(dithered_waveform, waveform):
def _add_noise_shaping(
dithered_waveform: Tensor,
waveform: Tensor
) -> Tensor:
r"""Noise shaping is calculated by error:
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
......@@ -1281,8 +1382,10 @@ def _add_noise_shaping(dithered_waveform, waveform):
return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(waveform, density_function="TPDF"):
# type: (Tensor, str) -> Tensor
def _apply_probability_distribution(
waveform: Tensor,
density_function: str = "TPDF"
) -> Tensor:
r"""Apply a probability distribution function on a waveform.
Triangular probability density function (TPDF) dither noise has a
......@@ -1297,14 +1400,14 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
probability_density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
waveform (Tensor): Tensor of audio of dimension (..., time)
probability_density_function (str, optional): The density function of a
continuous random variable (Default: ``"TPDF"``)
Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF`
Returns:
torch.Tensor: waveform dithered with TPDF
Tensor: waveform dithered with TPDF
"""
# pack batch
......@@ -1353,23 +1456,25 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
def dither(waveform, density_function="TPDF", noise_shaping=False):
# type: (Tensor, str, bool) -> Tensor
def dither(
waveform: Tensor,
density_function: str = "TPDF",
noise_shaping: bool = False
) -> Tensor:
r"""Dither increases the perceived dynamic range of audio stored at a
particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
waveform (Tensor): Tensor of audio of dimension (..., time)
density_function (str, optional): The density function of a continuous random variable (Default: ``"TPDF"``)
Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF`
noise_shaping (boolean): a filtering process that shapes the spectral
energy of quantisation error (Default: `False`)
noise_shaping (bool, optional): a filtering process that shapes the spectral
energy of quantisation error (Default: ``False``)
Returns:
torch.Tensor: waveform dithered
Tensor: waveform dithered
"""
dithered = _apply_probability_distribution(waveform, density_function=density_function)
......@@ -1379,8 +1484,12 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
return dithered
def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor
def _compute_nccf(
waveform: Tensor,
sample_rate: int,
frame_time: float,
freq_low: int
) -> Tensor:
r"""
Compute Normalized Cross-Correlation Function (NCCF).
......@@ -1390,7 +1499,7 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
where
:math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
:math:`w` is the waveform,
:math:`N` is the lenght of a frame,
:math:`N` is the length of a frame,
:math:`b_i` is the beginning of frame :math:`i`,
:math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
"""
......@@ -1411,12 +1520,8 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# Compute lags
output_lag = []
for lag in range(1, lags + 1):
s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[
..., :num_of_frames, :
]
s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[
..., :num_of_frames, :
]
s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
output_frames = (
(s1 * s2).sum(-1)
......@@ -1431,8 +1536,11 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
return nccf
def _combine_max(a, b, thresh=0.99):
# type: (Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float) -> Tuple[Tensor, Tensor]
def _combine_max(
a: Tuple[Tensor, Tensor],
b: Tuple[Tensor, Tensor],
thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
"""
Take value from first if bigger than a multiplicative factor of the second, elementwise.
"""
......@@ -1442,8 +1550,11 @@ def _combine_max(a, b, thresh=0.99):
return values, indices
def _find_max_per_frame(nccf, sample_rate, freq_high):
# type: (Tensor, int, int) -> Tensor
def _find_max_per_frame(
nccf: Tensor,
sample_rate: int,
freq_high: int
) -> Tensor:
r"""
For each frame, take the highest value of NCCF,
apply centered median smoothing, and convert to frequency.
......@@ -1472,8 +1583,10 @@ def _find_max_per_frame(nccf, sample_rate, freq_high):
return indices
def _median_smoothing(indices, win_length):
# type: (Tensor, int) -> Tensor
def _median_smoothing(
indices: Tensor,
win_length: int
) -> Tensor:
r"""
Apply median smoothing to the 1D tensor over the given window.
"""
......@@ -1494,31 +1607,28 @@ def _median_smoothing(indices, win_length):
def detect_pitch_frequency(
waveform,
sample_rate,
frame_time=10 ** (-2),
win_length=30,
freq_low=85,
freq_high=3400,
):
# type: (Tensor, int, float, int, int, int) -> Tensor
waveform: Tensor,
sample_rate: int,
frame_time: float = 10 ** (-2),
win_length: int = 30,
freq_low: int = 85,
freq_high: int = 3400,
) -> Tensor:
r"""Detect pitch frequency.
It is implemented using normalized cross-correlation function and median smoothing.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time)
waveform (Tensor): Tensor of audio of dimension (..., freq, time)
sample_rate (int): The sample rate of the waveform (Hz)
win_length (int): The window length for median smoothing (in number of frames)
freq_low (int): Lowest frequency that can be detected (Hz)
freq_high (int): Highest frequency that can be detected (Hz)
frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
Returns:
freq (torch.Tensor): Tensor of audio of dimension (..., frame)
Tensor: Tensor of freq of dimension (..., frame)
"""
dim = waveform.dim()
# pack batch
shape = list(waveform.size())
waveform = waveform.view([-1] + shape[-1:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment