Unverified Commit 4c221140 authored by Tomás Osório's avatar Tomás Osório Committed by GitHub
Browse files

Add inline typing to functional (#482)

* add typing to functional

* fix minor things

* fix flake8
parent a72dd836
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import math import math
from typing import Optional, Tuple
import torch import torch
from torch import Tensor
__all__ = [ __all__ = [
"istft", "istft",
...@@ -37,17 +39,16 @@ __all__ = [ ...@@ -37,17 +39,16 @@ __all__ = [
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore @torch.jit.ignore
def _stft( def _stft(
waveform, waveform: Tensor,
n_fft, n_fft: int,
hop_length, hop_length: Optional[int],
win_length, win_length: Optional[int],
window, window: Optional[Tensor],
center, center: bool,
pad_mode, pad_mode: str,
normalized, normalized: bool,
onesided, onesided: bool
): ) -> Tensor:
# type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
return torch.stft( return torch.stft(
waveform, waveform,
n_fft, n_fft,
...@@ -62,18 +63,17 @@ def _stft( ...@@ -62,18 +63,17 @@ def _stft(
def istft( def istft(
stft_matrix, # type: Tensor stft_matrix: Tensor,
n_fft, # type: int n_fft: int,
hop_length=None, # type: Optional[int] hop_length: Optional[int] = None,
win_length=None, # type: Optional[int] win_length: Optional[int] = None,
window=None, # type: Optional[Tensor] window: Optional[Tensor] = None,
center=True, # type: bool center: bool = True,
pad_mode="reflect", # type: str pad_mode: str = "reflect",
normalized=False, # type: bool normalized: bool = False,
onesided=True, # type: bool onesided: bool = True,
length=None, # type: Optional[int] length: Optional[int] = None,
): ) -> Tensor:
# type: (...) -> Tensor
r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft. r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
It has the same parameters (+ additional optional parameter of ``length``) and it should return the It has the same parameters (+ additional optional parameter of ``length``) and it should return the
least squares estimation of the original signal. The algorithm will check using the NOLA condition ( least squares estimation of the original signal. The algorithm will check using the NOLA condition (
...@@ -103,26 +103,26 @@ def istft( ...@@ -103,26 +103,26 @@ def istft(
IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984. IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
Args: Args:
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each stft_matrix (Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. It has a size of either (..., fft_size, n_frame, 2) column is a window. It has a size of either (..., fft_size, n_frame, 2)
n_fft (int): Size of Fourier transform n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames. hop_length (int or None, optional): The distance between neighboring sliding window frames.
(Default: ``win_length // 4``) (Default: ``win_length // 4``)
win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``) win_length (int or None, optional): The size of window frame and STFT filter. (Default: ``n_fft``)
window (Optional[torch.Tensor]): The optional window function. window (Tensor or None, optional): The optional window function.
(Default: ``torch.ones(win_length)``) (Default: ``torch.ones(win_length)``)
center (bool): Whether ``input`` was padded on both sides so center (bool, optional): Whether ``input`` was padded on both sides so
that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
(Default: ``True``) (Default: ``True``)
pad_mode (str): Controls the padding method used when ``center`` is True. (Default: pad_mode (str, optional): Controls the padding method used when ``center`` is True. (Default:
``'reflect'``) ``"reflect"``)
normalized (bool): Whether the STFT was normalized. (Default: ``False``) normalized (bool, optional): Whether the STFT was normalized. (Default: ``False``)
onesided (bool): Whether the STFT is onesided. (Default: ``True``) onesided (bool, optional): Whether the STFT is onesided. (Default: ``True``)
length (Optional[int]): The amount to trim the signal by (i.e. the length (int or None, optional): The amount to trim the signal by (i.e. the
original signal length). (Default: whole signal) original signal length). (Default: whole signal)
Returns: Returns:
torch.Tensor: Least squares estimation of the original signal of size (..., signal_length) Tensor: Least squares estimation of the original signal of size (..., signal_length)
""" """
stft_matrix_dim = stft_matrix.dim() stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim)
...@@ -226,26 +226,32 @@ def istft( ...@@ -226,26 +226,32 @@ def istft(
def spectrogram( def spectrogram(
waveform, pad, window, n_fft, hop_length, win_length, power, normalized waveform: Tensor,
): pad: int,
# type: (Tensor, int, Tensor, int, int, int, Optional[float], bool) -> Tensor window: Tensor,
n_fft: int,
hop_length: int,
win_length: int,
power: Optional[float],
normalized: bool
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex. The spectrogram can be either magnitude-only or complex.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time) waveform (Tensor): Tensor of audio of dimension (..., time)
pad (int): Two sided padding of signal pad (int): Two sided padding of signal
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT n_fft (int): Size of FFT
hop_length (int): Length of hop between STFT windows hop_length (int): Length of hop between STFT windows
win_length (int): Window size win_length (int): Window size
power (float): Exponent for the magnitude spectrogram, power (float or None): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. If None, then the complex spectrum is returned instead.
normalized (bool): Whether to normalize by magnitude after stft normalized (bool): Whether to normalize by magnitude after stft
Returns: Returns:
torch.Tensor: Dimension (..., freq, time), freq is Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of ``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame). Fourier bins, and time is the number of window hops (n_frame).
""" """
...@@ -275,9 +281,18 @@ def spectrogram( ...@@ -275,9 +281,18 @@ def spectrogram(
def griffinlim( def griffinlim(
specgram, window, n_fft, hop_length, win_length, power, normalized, n_iter, momentum, length, rand_init specgram: Tensor,
): window: Tensor,
# type: (Tensor, Tensor, int, int, int, float, bool, int, float, Optional[int], bool) -> Tensor n_fft: int,
hop_length: int,
win_length: int,
power: float,
normalized: bool,
n_iter: int,
momentum: float,
length: Optional[int],
rand_init: bool
) -> Tensor:
r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation. r"""Compute waveform from a linear scale magnitude spectrogram using the Griffin-Lim transformation.
Implementation ported from `librosa`. Implementation ported from `librosa`.
...@@ -295,22 +310,22 @@ def griffinlim( ...@@ -295,22 +310,22 @@ def griffinlim(
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args: Args:
specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames) specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
where freq is ``n_fft // 2 + 1``. where freq is ``n_fft // 2 + 1``.
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
hop_length (int): Length of hop between STFT windows. ( hop_length (int): Length of hop between STFT windows. (
Default: ``win_length // 2``) Default: ``win_length // 2``)
win_length (int): Window size. (Default: ``n_fft``) win_length (int): Window size. (Default: ``n_fft``)
power (float): Exponent for the magnitude spectrogram, power (float): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) (must be > 0) e.g., 1 for energy, 2 for power, etc.
normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``) normalized (bool): Whether to normalize by magnitude after stft.
n_iter (int): Number of iteration for phase recovery process. n_iter (int): Number of iteration for phase recovery process.
momentum (float): The momentum parameter for fast Griffin-Lim. momentum (float): The momentum parameter for fast Griffin-Lim.
Setting this to 0 recovers the original Griffin-Lim method. Setting this to 0 recovers the original Griffin-Lim method.
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99) Values near 1 can lead to faster convergence, but above 1 may not converge.
length (Optional[int]): Array length of the expected output. (Default: ``None``) length (int or None): Array length of the expected output.
rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``) rand_init (bool): Initializes phase randomly if True, to zero otherwise.
Returns: Returns:
torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given. torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
...@@ -331,7 +346,7 @@ def griffinlim( ...@@ -331,7 +346,7 @@ def griffinlim(
else: else:
angles = torch.zeros(batch, freq, frames) angles = torch.zeros(batch, freq, frames)
angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \ angles = torch.stack([angles.cos(), angles.sin()], dim=-1) \
.to(dtype=specgram.dtype, device=specgram.device) .to(dtype=specgram.dtype, device=specgram.device)
specgram = specgram.unsqueeze(-1).expand_as(angles) specgram = specgram.unsqueeze(-1).expand_as(angles)
# And initialize the previous iterate to 0 # And initialize the previous iterate to 0
...@@ -371,8 +386,13 @@ def griffinlim( ...@@ -371,8 +386,13 @@ def griffinlim(
return waveform return waveform
def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): def amplitude_to_DB(
# type: (Tensor, float, float, float, Optional[float]) -> Tensor x: Tensor,
multiplier: float,
amin: float,
db_multiplier: float,
top_db: Optional[float] = None
) -> Tensor:
r"""Turn a tensor from the power/amplitude scale to the decibel scale. r"""Turn a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input tensor, and so This output depends on the maximum value in the input tensor, and so
...@@ -380,15 +400,15 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): ...@@ -380,15 +400,15 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
a full clip. a full clip.
Args: Args:
x (torch.Tensor): Input tensor before being converted to decibel scale x (Tensor): Input tensor before being converted to decibel scale
multiplier (float): Use 10. for power and 20. for amplitude multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp ``x`` amin (float): Number to clamp ``x``
db_multiplier (float): Log10(max(reference value and amin)) db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number top_db (float or None, optional): Minimum negative cut-off in decibels. A reasonable number
is 80. (Default: ``None``) is 80. (Default: ``None``)
Returns: Returns:
torch.Tensor: Output tensor in decibel scale Tensor: Output tensor in decibel scale
""" """
x_db = multiplier * torch.log10(torch.clamp(x, min=amin)) x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
x_db -= multiplier * db_multiplier x_db -= multiplier * db_multiplier
...@@ -399,23 +419,31 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): ...@@ -399,23 +419,31 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
return x_db return x_db
def DB_to_amplitude(x, ref, power): def DB_to_amplitude(
# type: (Tensor, float, float) -> Tensor x: Tensor,
ref: float,
power: float
) -> Tensor:
r"""Turn a tensor from the decibel scale to the power/amplitude scale. r"""Turn a tensor from the decibel scale to the power/amplitude scale.
Args: Args:
x (torch.Tensor): Input tensor before being converted to power/amplitude scale. x (Tensor): Input tensor before being converted to power/amplitude scale.
ref (float): Reference which the output will be scaled by. ref (float): Reference which the output will be scaled by.
power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude. power (float): If power equals 1, will compute DB to power. If 0.5, will compute DB to amplitude.
Returns: Returns:
torch.Tensor: Output tensor in power/amplitude scale. Tensor: Output tensor in power/amplitude scale.
""" """
return ref * torch.pow(torch.pow(10.0, 0.1 * x), power) return ref * torch.pow(torch.pow(10.0, 0.1 * x), power)
def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): def create_fb_matrix(
# type: (int, float, float, int, int) -> Tensor n_freqs: int,
f_min: float,
f_max: float,
n_mels: int,
sample_rate: int
) -> Tensor:
r"""Create a frequency bin conversion matrix. r"""Create a frequency bin conversion matrix.
Args: Args:
...@@ -426,7 +454,7 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): ...@@ -426,7 +454,7 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
sample_rate (int): Sample rate of the audio waveform sample_rate (int): Sample rate of the audio waveform
Returns: Returns:
torch.Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``) Tensor: Triangular filter banks (fb matrix) of size (``n_freqs``, ``n_mels``)
meaning number of frequencies to highlight/apply to x the number of filterbanks. meaning number of frequencies to highlight/apply to x the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of Each column is a filterbank so that assuming there is a matrix A of
size (..., ``n_freqs``), the applied result would be size (..., ``n_freqs``), the applied result would be
...@@ -456,18 +484,21 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate): ...@@ -456,18 +484,21 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate):
return fb return fb
def create_dct(n_mfcc, n_mels, norm): def create_dct(
# type: (int, int, Optional[str]) -> Tensor n_mfcc: int,
n_mels: int,
norm: Optional[str]
) -> Tensor:
r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``), r"""Create a DCT transformation matrix with shape (``n_mels``, ``n_mfcc``),
normalized depending on norm. normalized depending on norm.
Args: Args:
n_mfcc (int): Number of mfc coefficients to retain n_mfcc (int): Number of mfc coefficients to retain
n_mels (int): Number of mel filterbanks n_mels (int): Number of mel filterbanks
norm (Optional[str]): Norm to use (either 'ortho' or None) norm (str or None): Norm to use (either 'ortho' or None)
Returns: Returns:
torch.Tensor: The transformation matrix, to be right-multiplied to Tensor: The transformation matrix, to be right-multiplied to
row-wise data of size (``n_mels``, ``n_mfcc``). row-wise data of size (``n_mels``, ``n_mfcc``).
""" """
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
...@@ -483,8 +514,10 @@ def create_dct(n_mfcc, n_mels, norm): ...@@ -483,8 +514,10 @@ def create_dct(n_mfcc, n_mels, norm):
return dct.t() return dct.t()
def mu_law_encoding(x, quantization_channels): def mu_law_encoding(
# type: (Tensor, int) -> Tensor x: Tensor,
quantization_channels: int
) -> Tensor:
r"""Encode signal based on mu-law companding. For more info see the r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -492,11 +525,11 @@ def mu_law_encoding(x, quantization_channels): ...@@ -492,11 +525,11 @@ def mu_law_encoding(x, quantization_channels):
returns a signal encoded with values from 0 to quantization_channels - 1. returns a signal encoded with values from 0 to quantization_channels - 1.
Args: Args:
x (torch.Tensor): Input tensor x (Tensor): Input tensor
quantization_channels (int): Number of channels quantization_channels (int): Number of channels
Returns: Returns:
torch.Tensor: Input after mu-law encoding Tensor: Input after mu-law encoding
""" """
mu = quantization_channels - 1.0 mu = quantization_channels - 1.0
if not x.is_floating_point(): if not x.is_floating_point():
...@@ -507,8 +540,10 @@ def mu_law_encoding(x, quantization_channels): ...@@ -507,8 +540,10 @@ def mu_law_encoding(x, quantization_channels):
return x_mu return x_mu
def mu_law_decoding(x_mu, quantization_channels): def mu_law_decoding(
# type: (Tensor, int) -> Tensor x_mu: Tensor,
quantization_channels: int
) -> Tensor:
r"""Decode mu-law encoded signal. For more info see the r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_ `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...@@ -516,11 +551,11 @@ def mu_law_decoding(x_mu, quantization_channels): ...@@ -516,11 +551,11 @@ def mu_law_decoding(x_mu, quantization_channels):
and returns a signal scaled between -1 and 1. and returns a signal scaled between -1 and 1.
Args: Args:
x_mu (torch.Tensor): Input tensor x_mu (Tensor): Input tensor
quantization_channels (int): Number of channels quantization_channels (int): Number of channels
Returns: Returns:
torch.Tensor: Input after mu-law decoding Tensor: Input after mu-law decoding
""" """
mu = quantization_channels - 1.0 mu = quantization_channels - 1.0
if not x_mu.is_floating_point(): if not x_mu.is_floating_point():
...@@ -531,65 +566,71 @@ def mu_law_decoding(x_mu, quantization_channels): ...@@ -531,65 +566,71 @@ def mu_law_decoding(x_mu, quantization_channels):
return x return x
def complex_norm(complex_tensor, power=1.0): def complex_norm(
# type: (Tensor, float) -> Tensor complex_tensor: Tensor,
power: float = 1.0
) -> Tensor:
r"""Compute the norm of complex tensor input. r"""Compute the norm of complex tensor input.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`). power (float): Power of the norm. (Default: `1.0`).
Returns: Returns:
torch.Tensor: Power of the normed input tensor. Shape of `(..., )` Tensor: Power of the normed input tensor. Shape of `(..., )`
""" """
if power == 1.0: if power == 1.0:
return torch.norm(complex_tensor, 2, -1) return torch.norm(complex_tensor, 2, -1)
return torch.norm(complex_tensor, 2, -1).pow(power) return torch.norm(complex_tensor, 2, -1).pow(power)
def angle(complex_tensor): def angle(
# type: (Tensor) -> Tensor complex_tensor: Tensor
) -> Tensor:
r"""Compute the angle of complex tensor input. r"""Compute the angle of complex tensor input.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
Return: Return:
torch.Tensor: Angle of a complex tensor. Shape of `(..., )` Tensor: Angle of a complex tensor. Shape of `(..., )`
""" """
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
def magphase(complex_tensor, power=1.0): def magphase(
# type: (Tensor, float) -> Tuple[Tensor, Tensor] complex_tensor: Tensor,
power: float = 1.0
) -> Tuple[Tensor, Tensor]:
r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase.
Args: Args:
complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
power (float): Power of the norm. (Default: `1.0`) power (float): Power of the norm. (Default: `1.0`)
Returns: Returns:
Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex tensor (Tensor, Tensor): The magnitude and phase of the complex tensor
""" """
mag = complex_norm(complex_tensor, power) mag = complex_norm(complex_tensor, power)
phase = angle(complex_tensor) phase = angle(complex_tensor)
return mag, phase return mag, phase
def phase_vocoder(complex_specgrams, rate, phase_advance): def phase_vocoder(
# type: (Tensor, float, Tensor) -> Tensor complex_specgrams: Tensor,
rate: float,
phase_advance: Tensor
) -> Tensor:
r"""Given a STFT tensor, speed up in time without modifying pitch by a r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of ``rate``. factor of ``rate``.
Args: Args:
complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)` complex_specgrams (Tensor): Dimension of `(..., freq, time, complex=2)`
rate (float): Speed-up factor rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
of (freq, 1)
Returns: Returns:
complex_specgrams_stretch (torch.Tensor): Dimension of `(..., Tensor: Complex Specgrams Stretch with dimension of `(..., freq, ceil(time/rate), complex=2)`
freq, ceil(time/rate), complex=2)`
Example Example
>>> freq, hop_length = 1025, 512 >>> freq, hop_length = 1025, 512
...@@ -650,22 +691,24 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): ...@@ -650,22 +691,24 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
return complex_specgrams_stretch return complex_specgrams_stretch
def lfilter(waveform, a_coeffs, b_coeffs): def lfilter(
# type: (Tensor, Tensor, Tensor) -> Tensor waveform: Tensor,
a_coeffs: Tensor,
b_coeffs: Tensor
) -> Tensor:
r"""Perform an IIR filter by evaluating difference equation. r"""Perform an IIR filter by evaluating difference equation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1. waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`. a_coeffs (Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`. Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`.
Must be same size as b_coeffs (pad with 0's as necessary). Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (torch.Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`. b_coeffs (Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
Lower delays coefficients are first, e.g. `[b0, b1, b2, ...]`. Lower delays coefficients are first, e.g. `[b0, b1, b2, ...]`.
Must be same size as a_coeffs (pad with 0's as necessary). Must be same size as a_coeffs (pad with 0's as necessary).
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1. Tensor: Waveform with dimension of `(..., time)`. Output will be clipped to -1 to 1.
""" """
dim = waveform.dim() dim = waveform.dim()
...@@ -674,17 +717,17 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -674,17 +717,17 @@ def lfilter(waveform, a_coeffs, b_coeffs):
shape = waveform.size() shape = waveform.size()
waveform = waveform.view(-1, shape[-1]) waveform = waveform.view(-1, shape[-1])
assert(a_coeffs.size(0) == b_coeffs.size(0)) assert (a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2) assert (len(waveform.size()) == 2)
assert(waveform.device == a_coeffs.device) assert (waveform.device == a_coeffs.device)
assert(b_coeffs.device == a_coeffs.device) assert (b_coeffs.device == a_coeffs.device)
device = waveform.device device = waveform.device
dtype = waveform.dtype dtype = waveform.dtype
n_channel, n_sample = waveform.size() n_channel, n_sample = waveform.size()
n_order = a_coeffs.size(0) n_order = a_coeffs.size(0)
n_sample_padded = n_sample + n_order - 1 n_sample_padded = n_sample + n_order - 1
assert(n_order > 0) assert (n_order > 0)
# Pad the input and create output # Pad the input and create output
padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device) padded_waveform = torch.zeros(n_channel, n_sample_padded, dtype=dtype, device=device)
...@@ -720,13 +763,20 @@ def lfilter(waveform, a_coeffs, b_coeffs): ...@@ -720,13 +763,20 @@ def lfilter(waveform, a_coeffs, b_coeffs):
return output return output
def biquad(waveform, b0, b1, b2, a0, a1, a2): def biquad(
# type: (Tensor, float, float, float, float, float, float) -> Tensor waveform: Tensor,
b0: float,
b1: float,
b2: float,
a0: float,
a1: float,
a2: float
) -> Tensor:
r"""Perform a biquad filter of input tensor. Initial conditions set to 0. r"""Perform a biquad filter of input tensor. Initial conditions set to 0.
https://en.wikipedia.org/wiki/Digital_biquad_filter https://en.wikipedia.org/wiki/Digital_biquad_filter
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
b0 (float): numerator coefficient of current input, x[n] b0 (float): numerator coefficient of current input, x[n]
b1 (float): numerator coefficient of input one time step ago x[n-1] b1 (float): numerator coefficient of input one time step ago x[n-1]
b2 (float): numerator coefficient of input two time steps ago x[n-2] b2 (float): numerator coefficient of input two time steps ago x[n-2]
...@@ -735,7 +785,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): ...@@ -735,7 +785,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
a2 (float): denominator coefficient of current output y[n-2] a2 (float): denominator coefficient of current output y[n-2]
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform with dimension of `(..., time)`
""" """
device = waveform.device device = waveform.device
...@@ -749,23 +799,26 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): ...@@ -749,23 +799,26 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2):
return output_waveform return output_waveform
def _dB2Linear(x): def _dB2Linear(x: float) -> float:
# type: (float) -> float
return math.exp(x * math.log(10) / 20.0) return math.exp(x * math.log(10) / 20.0)
def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def highpass_biquad(
# type: (Tensor, int, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform dimension of `(..., time)`
""" """
GAIN = 1. GAIN = 1.
...@@ -783,18 +836,22 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -783,18 +836,22 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): def lowpass_biquad(
# type: (Tensor, int, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
cutoff_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)` waveform (torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
cutoff_freq (float): filter cutoff frequency cutoff_freq (float): filter cutoff frequency
Q (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
""" """
GAIN = 1. GAIN = 1.
...@@ -812,18 +869,22 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): ...@@ -812,18 +869,22 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def allpass_biquad(waveform, sample_rate, central_freq, Q=0.707): def allpass_biquad(
# type: (Tensor, int, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design two-pole all-pass filter. Similar to SoX implementation. r"""Design two-pole all-pass filter. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform(torch.Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -841,20 +902,25 @@ def allpass_biquad(waveform, sample_rate, central_freq, Q=0.707): ...@@ -841,20 +902,25 @@ def allpass_biquad(waveform, sample_rate, central_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def bandpass_biquad(waveform, sample_rate, central_freq, Q=0.707, const_skirt_gain=False): def bandpass_biquad(
# type: (Tensor, int, float, float, bool) -> Tensor waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707,
const_skirt_gain: bool = False
) -> Tensor:
r"""Design two-pole band-pass filter. Similar to SoX implementation. r"""Design two-pole band-pass filter. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
const_skirt_gain (bool) : If ``True``, uses a constant skirt gain (peak gain = Q). const_skirt_gain (bool, optional) : If ``True``, uses a constant skirt gain (peak gain = Q).
If ``False``, uses a constant 0dB peak gain. (Default: ``False``) If ``False``, uses a constant 0dB peak gain. (Default: ``False``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -873,18 +939,22 @@ def bandpass_biquad(waveform, sample_rate, central_freq, Q=0.707, const_skirt_ga ...@@ -873,18 +939,22 @@ def bandpass_biquad(waveform, sample_rate, central_freq, Q=0.707, const_skirt_ga
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def bandreject_biquad(waveform, sample_rate, central_freq, Q=0.707): def bandreject_biquad(
# type: (Tensor, int, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707
) -> Tensor:
r"""Design two-pole band-reject filter. Similar to SoX implementation. r"""Design two-pole band-reject filter. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -902,19 +972,24 @@ def bandreject_biquad(waveform, sample_rate, central_freq, Q=0.707): ...@@ -902,19 +972,24 @@ def bandreject_biquad(waveform, sample_rate, central_freq, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): def equalizer_biquad(
# type: (Tensor, int, float, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
center_freq: float,
gain: float,
Q: float = 0.707
) -> Tensor:
r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation. r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation.
Args: Args:
waveform (torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
center_freq (float): filter's central frequency center_freq (float): filter's central frequency
gain (float): desired gain at the boost (or attenuation) in dB gain (float): desired gain at the boost (or attenuation) in dB
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``)
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
""" """
w0 = 2 * math.pi * center_freq / sample_rate w0 = 2 * math.pi * center_freq / sample_rate
A = math.exp(gain / 40.0 * math.log(10)) A = math.exp(gain / 40.0 * math.log(10))
...@@ -929,21 +1004,26 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): ...@@ -929,21 +1004,26 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def band_biquad(waveform, sample_rate, central_freq, Q=0.707, noise=False): def band_biquad(
# type: (Tensor, int, float, float, bool) -> Tensor waveform: Tensor,
sample_rate: int,
central_freq: float,
Q: float = 0.707,
noise: bool = False
) -> Tensor:
r"""Design two-pole band filter. Similar to SoX implementation. r"""Design two-pole band filter. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
central_freq (float): central frequency (in Hz) central_freq (float): central frequency (in Hz)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
noise (bool) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion). noise (bool, optional) : If ``True``, uses the alternate mode for un-pitched audio (e.g. percussion).
If ``False``, uses mode oriented to pitched audio, i.e. voice, singing, If ``False``, uses mode oriented to pitched audio, i.e. voice, singing,
or instrumental music. (Default: ``False``) or instrumental music (Default: ``False``).
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -969,19 +1049,24 @@ def band_biquad(waveform, sample_rate, central_freq, Q=0.707, noise=False): ...@@ -969,19 +1049,24 @@ def band_biquad(waveform, sample_rate, central_freq, Q=0.707, noise=False):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def treble_biquad(waveform, sample_rate, gain, central_freq=3000, Q=0.707): def treble_biquad(
# type: (Tensor, int, float, float, float) -> Tensor waveform: Tensor,
sample_rate: int,
gain: float,
central_freq: float = 3000,
Q: float = 0.707
) -> Tensor:
r"""Design a treble tone-control effect. Similar to SoX implementation. r"""Design a treble tone-control effect. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
gain (float): desired gain at the boost (or attenuation) in dB. gain (float): desired gain at the boost (or attenuation) in dB.
central_freq (float): central frequency (in Hz). (Default: ``3000``) central_freq (float, optional): central frequency (in Hz). (Default: ``3000``)
q_factor (float): https://en.wikipedia.org/wiki/Q_factor Q (float, optional): https://en.wikipedia.org/wiki/Q_factor (Default: ``0.707``).
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -1005,16 +1090,18 @@ def treble_biquad(waveform, sample_rate, gain, central_freq=3000, Q=0.707): ...@@ -1005,16 +1090,18 @@ def treble_biquad(waveform, sample_rate, gain, central_freq=3000, Q=0.707):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def deemph_biquad(waveform, sample_rate): def deemph_biquad(
# type: (Tensor, int) -> Tensor waveform: Tensor,
sample_rate: int
) -> Tensor:
r"""Apply ISO 908 CD de-emphasis (shelving) IIR filter. Similar to SoX implementation. r"""Apply ISO 908 CD de-emphasis (shelving) IIR filter. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, Allowed sample rate ``44100`` or ``48000`` sample_rate (int): sampling rate of the waveform, Allowed sample rate ``44100`` or ``48000``
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -1050,17 +1137,19 @@ def deemph_biquad(waveform, sample_rate): ...@@ -1050,17 +1137,19 @@ def deemph_biquad(waveform, sample_rate):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def riaa_biquad(waveform, sample_rate): def riaa_biquad(
# type: (Tensor, int) -> Tensor waveform: Tensor,
sample_rate: int
) -> Tensor:
r"""Apply RIAA vinyl playback equalisation. Similar to SoX implementation. r"""Apply RIAA vinyl playback equalisation. Similar to SoX implementation.
Args: Args:
waveform(torch.Tensor): audio waveform of dimension of `(..., time)` waveform (Tensor): audio waveform of dimension of `(..., time)`
sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz). sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz).
Allowed sample rates in Hz : ``44100``,``48000``,``88200``,``96000`` Allowed sample rates in Hz : ``44100``,``48000``,``88200``,``96000``
Returns: Returns:
output_waveform (torch.Tensor): Dimension of `(..., time)` Tensor: Waveform of dimension of `(..., time)`
References: References:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
...@@ -1102,7 +1191,7 @@ def riaa_biquad(waveform, sample_rate): ...@@ -1102,7 +1191,7 @@ def riaa_biquad(waveform, sample_rate):
a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y) a_re = a0 + a1 * math.cos(-y) + a2 * math.cos(-2 * y)
b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y) b_im = b1 * math.sin(-y) + b2 * math.sin(-2 * y)
a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y) a_im = a1 * math.sin(-y) + a2 * math.sin(-2 * y)
g = 1 / math.sqrt((b_re**2 + b_im**2) / (a_re**2 + a_im**2)) g = 1 / math.sqrt((b_re ** 2 + b_im ** 2) / (a_re ** 2 + a_im ** 2))
b0 *= g b0 *= g
b1 *= g b1 *= g
...@@ -1111,8 +1200,12 @@ def riaa_biquad(waveform, sample_rate): ...@@ -1111,8 +1200,12 @@ def riaa_biquad(waveform, sample_rate):
return biquad(waveform, b0, b1, b2, a0, a1, a2) return biquad(waveform, b0, b1, b2, a0, a1, a2)
def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): def mask_along_axis_iid(
# type: (Tensor, int, float, int) -> Tensor specgrams: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r""" r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
...@@ -1125,7 +1218,7 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): ...@@ -1125,7 +1218,7 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
Returns: Returns:
torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
""" """
if axis != 2 and axis != 3: if axis != 2 and axis != 3:
...@@ -1147,8 +1240,12 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): ...@@ -1147,8 +1240,12 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis):
return specgrams return specgrams
def mask_along_axis(specgram, mask_param, mask_value, axis): def mask_along_axis(
# type: (Tensor, int, float, int) -> Tensor specgram: Tensor,
mask_param: int,
mask_value: float,
axis: int
) -> Tensor:
r""" r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``. ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
...@@ -1161,7 +1258,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): ...@@ -1161,7 +1258,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
Returns: Returns:
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) Tensor: Masked spectrogram of dimensions (channel, freq, time)
""" """
# pack batch # pack batch
...@@ -1188,8 +1285,11 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): ...@@ -1188,8 +1285,11 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
return specgram return specgram
def compute_deltas(specgram, win_length=5, mode="replicate"): def compute_deltas(
# type: (Tensor, int, str) -> Tensor specgram: Tensor,
win_length: int = 5,
mode: str = "replicate"
) -> Tensor:
r"""Compute delta coefficients of a tensor, usually a spectrogram: r"""Compute delta coefficients of a tensor, usually a spectrogram:
.. math:: .. math::
...@@ -1200,12 +1300,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -1200,12 +1300,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
:math:`N` is (`win_length`-1)//2. :math:`N` is (`win_length`-1)//2.
Args: Args:
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time) specgram (Tensor): Tensor of audio of dimension (..., freq, time)
win_length (int): The window length used for computing delta win_length (int, optional): The window length used for computing delta (Default: ``5``)
mode (str): Mode parameter passed to padding mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
Returns: Returns:
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time) Tensor: Tensor of deltas of dimension (..., freq, time)
Example Example
>>> specgram = torch.randn(1, 40, 1000) >>> specgram = torch.randn(1, 40, 1000)
...@@ -1226,11 +1326,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -1226,11 +1326,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode) specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
kernel = ( kernel = (torch.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype).repeat(specgram.shape[1], 1, 1))
torch
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
.repeat(specgram.shape[1], 1, 1)
)
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
...@@ -1240,16 +1336,18 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -1240,16 +1336,18 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output return output
def gain(waveform, gain_db=1.0): def gain(
# type: (Tensor, float) -> Tensor waveform: Tensor,
gain_db: float = 1.0
) -> Tensor:
r"""Apply amplification or attenuation to the whole waveform. r"""Apply amplification or attenuation to the whole waveform.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time). waveform (Tensor): Tensor of audio of dimension (..., time).
gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`). gain_db (float, optional) Gain adjustment in decibels (dB) (Default: ``1.0``).
Returns: Returns:
torch.Tensor: the whole waveform amplified by gain_db. Tensor: the whole waveform amplified by gain_db.
""" """
if (gain_db == 0): if (gain_db == 0):
return waveform return waveform
...@@ -1259,7 +1357,10 @@ def gain(waveform, gain_db=1.0): ...@@ -1259,7 +1357,10 @@ def gain(waveform, gain_db=1.0):
return waveform * ratio return waveform * ratio
def _add_noise_shaping(dithered_waveform, waveform): def _add_noise_shaping(
dithered_waveform: Tensor,
waveform: Tensor
) -> Tensor:
r"""Noise shaping is calculated by error: r"""Noise shaping is calculated by error:
error[n] = dithered[n] - original[n] error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1] noise_shaped_waveform[n] = dithered[n] + error[n-1]
...@@ -1281,8 +1382,10 @@ def _add_noise_shaping(dithered_waveform, waveform): ...@@ -1281,8 +1382,10 @@ def _add_noise_shaping(dithered_waveform, waveform):
return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:]) return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(waveform, density_function="TPDF"): def _apply_probability_distribution(
# type: (Tensor, str) -> Tensor waveform: Tensor,
density_function: str = "TPDF"
) -> Tensor:
r"""Apply a probability distribution function on a waveform. r"""Apply a probability distribution function on a waveform.
Triangular probability density function (TPDF) dither noise has a Triangular probability density function (TPDF) dither noise has a
...@@ -1297,14 +1400,14 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): ...@@ -1297,14 +1400,14 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
The relationship of probabilities of results follows a bell-shaped, The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources. or Gaussian curve, typical of dither generated by analog sources.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time) waveform (Tensor): Tensor of audio of dimension (..., time)
probability_density_function (string): The density function of a probability_density_function (str, optional): The density function of a
continuous random variable (Default: `TPDF`) continuous random variable (Default: ``"TPDF"``)
Options: Triangular Probability Density Function - `TPDF` Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF` Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF` Gaussian Probability Density Function - `GPDF`
Returns: Returns:
torch.Tensor: waveform dithered with TPDF Tensor: waveform dithered with TPDF
""" """
# pack batch # pack batch
...@@ -1353,23 +1456,25 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): ...@@ -1353,23 +1456,25 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:]) return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
def dither(waveform, density_function="TPDF", noise_shaping=False): def dither(
# type: (Tensor, str, bool) -> Tensor waveform: Tensor,
density_function: str = "TPDF",
noise_shaping: bool = False
) -> Tensor:
r"""Dither increases the perceived dynamic range of audio stored at a r"""Dither increases the perceived dynamic range of audio stored at a
particular bit-depth by eliminating nonlinear truncation distortion particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization). (i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time) waveform (Tensor): Tensor of audio of dimension (..., time)
density_function (string): The density function of a density_function (str, optional): The density function of a continuous random variable (Default: ``"TPDF"``)
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF` Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF` Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF` Gaussian Probability Density Function - `GPDF`
noise_shaping (boolean): a filtering process that shapes the spectral noise_shaping (bool, optional): a filtering process that shapes the spectral
energy of quantisation error (Default: `False`) energy of quantisation error (Default: ``False``)
Returns: Returns:
torch.Tensor: waveform dithered Tensor: waveform dithered
""" """
dithered = _apply_probability_distribution(waveform, density_function=density_function) dithered = _apply_probability_distribution(waveform, density_function=density_function)
...@@ -1379,8 +1484,12 @@ def dither(waveform, density_function="TPDF", noise_shaping=False): ...@@ -1379,8 +1484,12 @@ def dither(waveform, density_function="TPDF", noise_shaping=False):
return dithered return dithered
def _compute_nccf(waveform, sample_rate, frame_time, freq_low): def _compute_nccf(
# type: (Tensor, int, float, int) -> Tensor waveform: Tensor,
sample_rate: int,
frame_time: float,
freq_low: int
) -> Tensor:
r""" r"""
Compute Normalized Cross-Correlation Function (NCCF). Compute Normalized Cross-Correlation Function (NCCF).
...@@ -1390,7 +1499,7 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low): ...@@ -1390,7 +1499,7 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
where where
:math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`, :math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
:math:`w` is the waveform, :math:`w` is the waveform,
:math:`N` is the lenght of a frame, :math:`N` is the length of a frame,
:math:`b_i` is the beginning of frame :math:`i`, :math:`b_i` is the beginning of frame :math:`i`,
:math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`. :math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
""" """
...@@ -1411,12 +1520,8 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low): ...@@ -1411,12 +1520,8 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# Compute lags # Compute lags
output_lag = [] output_lag = []
for lag in range(1, lags + 1): for lag in range(1, lags + 1):
s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[ s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
..., :num_of_frames, : s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[..., :num_of_frames, :]
]
s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[
..., :num_of_frames, :
]
output_frames = ( output_frames = (
(s1 * s2).sum(-1) (s1 * s2).sum(-1)
...@@ -1431,8 +1536,11 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low): ...@@ -1431,8 +1536,11 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
return nccf return nccf
def _combine_max(a, b, thresh=0.99): def _combine_max(
# type: (Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float) -> Tuple[Tensor, Tensor] a: Tuple[Tensor, Tensor],
b: Tuple[Tensor, Tensor],
thresh: float = 0.99
) -> Tuple[Tensor, Tensor]:
""" """
Take value from first if bigger than a multiplicative factor of the second, elementwise. Take value from first if bigger than a multiplicative factor of the second, elementwise.
""" """
...@@ -1442,8 +1550,11 @@ def _combine_max(a, b, thresh=0.99): ...@@ -1442,8 +1550,11 @@ def _combine_max(a, b, thresh=0.99):
return values, indices return values, indices
def _find_max_per_frame(nccf, sample_rate, freq_high): def _find_max_per_frame(
# type: (Tensor, int, int) -> Tensor nccf: Tensor,
sample_rate: int,
freq_high: int
) -> Tensor:
r""" r"""
For each frame, take the highest value of NCCF, For each frame, take the highest value of NCCF,
apply centered median smoothing, and convert to frequency. apply centered median smoothing, and convert to frequency.
...@@ -1472,8 +1583,10 @@ def _find_max_per_frame(nccf, sample_rate, freq_high): ...@@ -1472,8 +1583,10 @@ def _find_max_per_frame(nccf, sample_rate, freq_high):
return indices return indices
def _median_smoothing(indices, win_length): def _median_smoothing(
# type: (Tensor, int) -> Tensor indices: Tensor,
win_length: int
) -> Tensor:
r""" r"""
Apply median smoothing to the 1D tensor over the given window. Apply median smoothing to the 1D tensor over the given window.
""" """
...@@ -1494,31 +1607,28 @@ def _median_smoothing(indices, win_length): ...@@ -1494,31 +1607,28 @@ def _median_smoothing(indices, win_length):
def detect_pitch_frequency( def detect_pitch_frequency(
waveform, waveform: Tensor,
sample_rate, sample_rate: int,
frame_time=10 ** (-2), frame_time: float = 10 ** (-2),
win_length=30, win_length: int = 30,
freq_low=85, freq_low: int = 85,
freq_high=3400, freq_high: int = 3400,
): ) -> Tensor:
# type: (Tensor, int, float, int, int, int) -> Tensor
r"""Detect pitch frequency. r"""Detect pitch frequency.
It is implemented using normalized cross-correlation function and median smoothing. It is implemented using normalized cross-correlation function and median smoothing.
Args: Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time) waveform (Tensor): Tensor of audio of dimension (..., freq, time)
sample_rate (int): The sample rate of the waveform (Hz) sample_rate (int): The sample rate of the waveform (Hz)
win_length (int): The window length for median smoothing (in number of frames) frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
freq_low (int): Lowest frequency that can be detected (Hz) win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
freq_high (int): Highest frequency that can be detected (Hz) freq_low (int, optional): Lowest frequency that can be detected (Hz) (Default: ``85``).
freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
Returns: Returns:
freq (torch.Tensor): Tensor of audio of dimension (..., frame) Tensor: Tensor of freq of dimension (..., frame)
""" """
dim = waveform.dim()
# pack batch # pack batch
shape = list(waveform.size()) shape = list(waveform.size())
waveform = waveform.view([-1] + shape[-1:]) waveform = waveform.view([-1] + shape[-1:])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment