Commit 0902494e authored by jamarshon, committed by cpuhrsch

torch.functional Docs (#140)

parent c569b40f
@@ -68,6 +68,7 @@ instance/
 # Sphinx documentation
 docs/_build/
+docs/src/
 # PyBuilder
 target/
...
@@ -208,6 +208,7 @@ texinfo_documents = [
 intersphinx_mapping = {
     'python': ('https://docs.python.org/', None),
     'numpy': ('https://docs.scipy.org/doc/numpy/', None),
+    'torch': ('https://pytorch.org/docs/stable/', None),
 }

 # -- A patch that prevents Sphinx from cross-referencing ivar tags -------
...
.. role:: hidden
    :class: hidden-section

torchaudio.functional
======================

.. currentmodule:: torchaudio.functional

Functions to perform common audio operations.

:hidden:`scale`
~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: scale

:hidden:`pad_trim`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pad_trim

:hidden:`downmix_mono`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: downmix_mono

:hidden:`LC2CL`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: LC2CL

:hidden:`istft`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: istft

:hidden:`spectrogram`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: spectrogram

:hidden:`create_fb_matrix`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: create_fb_matrix

:hidden:`spectrogram_to_DB`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: spectrogram_to_DB

:hidden:`create_dct`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: create_dct

:hidden:`BLC2CBL`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: BLC2CBL

:hidden:`mu_law_encoding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mu_law_encoding

:hidden:`mu_law_expanding`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: mu_law_expanding
@@ -12,6 +12,7 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
    compliance.kaldi
    kaldi_io
    transforms
+   functional
    legacy

 .. automodule:: torchaudio
...
@@ -21,16 +21,17 @@ __all__ = [
 @torch.jit.script
 def scale(tensor, factor):
     # type: (Tensor, int) -> Tensor
-    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
-    to a floating point number between -1.0 and 1.0. Note the 16-bit number is
-    called the "bit depth" or "precision", not to be confused with "bit rate".
+    r"""Scale audio tensor from a 16-bit integer (represented as a
+    :class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0.
+    Note the 16-bit number is called the "bit depth" or "precision", not to be
+    confused with "bit rate".

-    Inputs:
-        tensor (Tensor): Tensor of audio of size (Samples x Channels)
+    Args:
+        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
         factor (int): Maximum value of input tensor

-    Outputs:
-        Tensor: Scaled by the scale factor
+    Returns:
+        torch.Tensor: Scaled by the scale factor
     """
     if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)
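For orientation, a minimal usage sketch of the documented `scale` signature (values illustrative, not part of the commit):

```python
import torch
import torchaudio.functional as F

# int16-range samples in a (n, c) tensor; 2 ** 15 is the int16 full scale
sig = torch.tensor([[0], [16384], [-32768]], dtype=torch.int16)
scaled = F.scale(sig, 2 ** 15)  # converted to float32, values roughly in [-1.0, 1.0]
```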
@@ -41,17 +42,17 @@ def scale(tensor, factor):
 @torch.jit.script
 def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
     # type: (Tensor, int, int, int, float) -> Tensor
-    """Pad/Trim a 2d-Tensor (Signal or Labels)
+    r"""Pad/trim a 2D tensor (signal or labels).

-    Inputs:
-        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
+    Args:
+        tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
         ch_dim (int): Dimension of channel (not size)
         max_len (int): Length to which the tensor will be padded
         len_dim (int): Dimension of length (not size)
         fill_value (float): Value to fill in

-    Outputs:
-        Tensor: Padded/trimmed tensor
+    Returns:
+        torch.Tensor: Padded/trimmed tensor
     """
     if max_len > tensor.size(len_dim):
         # array of [padding_left, padding_right, padding_top, padding_bottom]
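A quick sketch of `pad_trim` as documented, with shapes assumed for illustration:

```python
import torch
import torchaudio.functional as F

sig = torch.randn(1000, 2)                # (n, c): 1000 samples, 2 channels
padded = F.pad_trim(sig, 1, 1200, 0, 0.)  # ch_dim=1, len_dim=0 -> shape (1200, 2)
trimmed = F.pad_trim(sig, 1, 800, 0, 0.)  # -> shape (800, 2)
```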
@@ -71,14 +72,14 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
 @torch.jit.script
 def downmix_mono(tensor, ch_dim):
     # type: (Tensor, int) -> Tensor
-    """Downmix any stereo signals to mono.
+    r"""Downmix any stereo signals to mono.

-    Inputs:
-        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
+    Args:
+        tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c)
         ch_dim (int): Dimension of channel (not size)

-    Outputs:
-        Tensor: Mono signal
+    Returns:
+        torch.Tensor: Mono signal
     """
     if not tensor.is_floating_point():
         tensor = tensor.to(torch.float32)
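An illustrative call to `downmix_mono`, assuming it reduces the channel dimension by averaging:

```python
import torch
import torchaudio.functional as F

stereo = torch.randn(2, 1000)     # (c, n)
mono = F.downmix_mono(stereo, 0)  # channel dim 0 collapsed into a mono signal
```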
@@ -90,13 +91,13 @@ def downmix_mono(tensor, ch_dim):
 @torch.jit.script
 def LC2CL(tensor):
     # type: (Tensor) -> Tensor
-    """Permute a 2d tensor from samples (n x c) to (c x n)
+    r"""Permute a 2D tensor from samples (n, c) to (c, n).

-    Inputs:
-        tensor (Tensor): Tensor of audio signal with shape (LxC)
+    Args:
+        tensor (torch.Tensor): Tensor of audio signal with shape (n, c)

-    Outputs:
-        Tensor: Tensor of audio signal with shape (CxL)
+    Returns:
+        torch.Tensor: Tensor of audio signal with shape (c, n)
     """
     return tensor.transpose(0, 1).contiguous()
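`LC2CL` is a pure layout permutation, as the one-line body above shows; a sketch:

```python
import torch
import torchaudio.functional as F

sig_nc = torch.randn(1000, 2)  # (n, c): samples first
sig_cn = F.LC2CL(sig_nc)       # (c, n): channels first, returned contiguous
```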
@@ -119,41 +120,54 @@ def istft(stft_matrix,  # type: Tensor
           ):
     # type: (...) -> Tensor
     r""" Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
-    It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
+    It has the same parameters (+ additional optional parameter of ``length``) and it should return the
     least squares estimation of the original signal. The algorithm will check using the NOLA condition (
     nonzero overlap).
-    Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelop
+
+    Important consideration in the parameters ``window`` and ``center`` so that the envelop
     created by the summation of all the windows is never zero at certain point in time. Specifically,
-    :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \neq 0`.
+    :math:`\sum_{t=-\infty}^{\infty} w^2[n-t\times hop\_length] \cancel{=} 0`.
+
     Since stft discards elements at the end of the signal if they do not fit in a frame, the
-    istft may return a shorter signal than the original signal (can occur if :attr:`center` is False
+    istft may return a shorter signal than the original signal (can occur if ``center`` is False
     since the signal isn't padded).
-    If :attr:`center` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
+
+    If ``center`` is True, then there will be padding e.g. 'constant', 'reflect', etc. Left padding
     can be trimmed off exactly because they can be calculated but right padding cannot be calculated
     without additional information.
+
     Example: Suppose the last window is:
         [17, 18, 0, 0, 0] vs [18, 0, 0, 0, 0]
+
     The n_frames, hop_length, win_length are all the same which prevents the calculation of right padding.
-    These additional values could be zeros or a reflection of the signal so providing :attr:`length`
-    could be useful. If :attr:`length` is None then padding will be aggressively removed (some loss of signal).
+    These additional values could be zeros or a reflection of the signal so providing ``length``
+    could be useful. If ``length`` is ``None`` then padding will be aggressively removed
+    (some loss of signal).
+
     [1] D. W. Griffin and J. S. Lim, “Signal estimation from modified short-time Fourier transform,”
     IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
-    Inputs:
-        stft_matrix (Tensor): output of stft where each row of a batch is a frequency and each column is
-            a window. it has a shape of either (batch, fft_size, n_frames, 2) or (fft_size, n_frames, 2)
-        n_fft (int): size of Fourier transform
-        hop_length (Optional[int]): the distance between neighboring sliding window frames. (Default: win_length // 4)
-        win_length (Optional[int]): the size of window frame and STFT filter. (Default: n_fft)
-        window (Optional[Tensor]): the optional window function. (Default: torch.ones(win_length))
-        center (bool): whether :attr:`input` was padded on both sides so
+
+    Args:
+        stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
+            column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
+            fft_size, n_frames, 2)
+        n_fft (int): Size of Fourier transform
+        hop_length (Optional[int]): The distance between neighboring sliding window frames.
+            (Default: ``win_length // 4``)
+        win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
+        window (Optional[torch.Tensor]): The optional window function.
+            (Default: ``torch.ones(win_length)``)
+        center (bool): Whether ``input`` was padded on both sides so
             that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`
-        pad_mode (str): controls the padding method used when :attr:`center` is ``True``
-        normalized (bool): whether the STFT was normalized
-        onesided (bool): whether the STFT is onesided
-        length (Optional[int]): the amount to trim the signal by (i.e. the
+        pad_mode (str): Controls the padding method used when ``center`` is ``True``
+        normalized (bool): Whether the STFT was normalized
+        onesided (bool): Whether the STFT is onesided
+        length (Optional[int]): The amount to trim the signal by (i.e. the
             original signal length). (Default: whole signal)
-    Outputs:
-        Tensor: least squares estimation of the original signal of size (batch, signal_length) or (signal_length)
+
+    Returns:
+        torch.Tensor: Least squares estimation of the original signal of size
+        (batch, signal_length) or (signal_length)
     """
     stft_matrix_dim = stft_matrix.dim()
     assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
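A hedged round-trip sketch for `istft`, assuming the contemporaneous `torch.stft` that returns a real tensor with a trailing dimension of 2; a Hann window with the default hop of `win_length // 4` satisfies the NOLA condition:

```python
import torch
import torchaudio.functional as F

n_fft = 512
signal = torch.randn(4000)
window = torch.hann_window(n_fft)

spec = torch.stft(signal, n_fft, window=window)  # (n_fft // 2 + 1, n_frames, 2)
recovered = F.istft(spec, n_fft, window=window, length=signal.size(0))
# `recovered` is the least-squares estimate of `signal`; passing `length`
# keeps the output at the original signal length
```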
@@ -241,22 +255,21 @@ def istft(stft_matrix,  # type: Tensor
 @torch.jit.script
 def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
-    """Create a spectrogram from a raw audio signal
+    r"""Create a spectrogram from a raw audio signal.

-    Inputs:
-        sig (Tensor): Tensor of audio of size (c, n)
-        pad (int): two sided padding of signal
-        window (Tensor): window_tensor
-        n_fft (int): size of fft
-        hop (int): length of hop between STFT windows
-        ws (int): window size
-        power (int > 0 ) : Exponent for the magnitude spectrogram,
-            e.g., 1 for energy, 2 for power, etc. (must be > 0)
-        normalize (bool) : whether to normalize by magnitude after stft
+    Args:
+        sig (torch.Tensor): Tensor of audio of size (c, n)
+        pad (int): Two sided padding of signal
+        window (torch.Tensor): Window tensor
+        n_fft (int): Size of fft
+        hop (int): Length of hop between STFT windows
+        ws (int): Window size
+        power (int) : Exponent for the magnitude spectrogram,
+            e.g., 1 for energy, 2 for power, etc.
+        normalize (bool) : Whether to normalize by magnitude after stft

-    Outputs:
-        Tensor: channels x hops x n_fft (c, l, f), where channels
+    Returns:
+        torch.Tensor: Channels x hops x n_fft (c, l, f), where channels
         is unchanged, hops is the number of hops, and n_fft is the
         number of fourier bins, which should be the window size divided
         by 2 plus 1.
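A sketch of the positional argument order (sig, pad, window, n_fft, hop, ws, power, normalize), with illustrative values that are not from the commit:

```python
import torch
import torchaudio.functional as F

sig = torch.randn(1, 16000)  # (c, n): mono, one second at 16 kHz
n_fft = 400
window = torch.hann_window(n_fft)
# pad=0, hop=200, ws=n_fft, power=2 (power spectrogram), normalize=False
spec = F.spectrogram(sig, 0, window, n_fft, 200, n_fft, 2, False)
# expected shape: (1, n_hops, n_fft // 2 + 1)
```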
@@ -280,17 +293,16 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
 @torch.jit.script
 def create_fb_matrix(n_stft, f_min, f_max, n_mels):
     # type: (int, float, float, int) -> Tensor
-    """ Create a frequency bin conversion matrix.
+    r""" Create a frequency bin conversion matrix.

-    Inputs:
-        n_stft (int): number of filter banks from spectrogram
-        f_min (float): minimum frequency
-        f_max (float): maximum frequency
-        n_mels (int): number of mel bins
+    Args:
+        n_stft (int): Number of filter banks from spectrogram
+        f_min (float): Minimum frequency
+        f_max (float): Maximum frequency
+        n_mels (int): Number of mel bins

-    Outputs:
-        Tensor: triangular filter banks (fb matrix)
+    Returns:
+        torch.Tensor: Triangular filter banks (fb matrix)
     """
     # get stft freq bins
     stft_freqs = torch.linspace(f_min, f_max, n_stft)
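A sketch of applying the filter bank to a spectrogram; the (n_stft, n_mels) orientation of the returned matrix is an assumption made for this illustration:

```python
import torch
import torchaudio.functional as F

spec = torch.rand(1, 81, 201)                # (c, l, n_stft) power spectrogram
fb = F.create_fb_matrix(201, 0., 8000., 40)  # assumed shape: (n_stft, n_mels)
mel_spec = torch.matmul(spec, fb)            # -> (1, 81, 40)
```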
@@ -315,22 +327,22 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
 @torch.jit.script
 def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
     # type: (Tensor, float, float, float, Optional[float]) -> Tensor
-    """Turns a spectrogram from the power/amplitude scale to the decibel scale.
+    r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.

     This output depends on the maximum value in the input spectrogram, and so
     may return different values for an audio clip split into snippets vs. a
     a full clip.

-    Inputs:
-        spec (Tensor): normal STFT
-        multiplier (float): use 10. for power and 20. for amplitude
-        amin (float): number to clamp spec
-        db_multiplier (float): log10(max(reference value and amin))
-        top_db (Optional[float]): minimum negative cut-off in decibels. A reasonable number
+    Args:
+        spec (torch.Tensor): Normal STFT
+        multiplier (float): Use 10. for power and 20. for amplitude
+        amin (float): Number to clamp spec
+        db_multiplier (float): Log10(max(reference value and amin))
+        top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
             is 80.

-    Outputs:
-        Tensor: spectrogram in DB
+    Returns:
+        torch.Tensor: Spectrogram in DB
     """
     spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
     spec_db -= multiplier * db_multiplier
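An illustrative call; `ref` is a hypothetical reference value, and `db_multiplier` is computed from it exactly as the docstring describes:

```python
import math
import torch
import torchaudio.functional as F

power_spec = torch.rand(1, 81, 201)  # a power spectrogram, so multiplier is 10.
amin, ref = 1e-10, 1.0
db = F.spectrogram_to_DB(power_spec, 10., amin, math.log10(max(amin, ref)), 80.)
```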
@@ -345,17 +357,16 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
 @torch.jit.script
 def create_dct(n_mfcc, n_mels, norm):
     # type: (int, int, Optional[str]) -> Tensor
-    """
-    Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
-    normalized depending on norm
+    r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
+    normalized depending on norm.

-    Inputs:
-        n_mfcc (int) : number of mfc coefficients to retain
-        n_mels (int): number of MEL bins
-        norm (Optional[str]) : norm to use (either 'ortho' or None)
+    Args:
+        n_mfcc (int) : Number of mfc coefficients to retain
+        n_mels (int): Number of MEL bins
+        norm (Optional[str]) : Norm to use (either 'ortho' or None)

-    Outputs:
-        Tensor: The transformation matrix, to be right-multiplied to row-wise data.
+    Returns:
+        torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data.
     """
     outdim = n_mfcc
     dim = n_mels
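A sketch of the right-multiplication the docstring mentions, with assumed illustrative shapes:

```python
import torch
import torchaudio.functional as F

dct_mat = F.create_dct(13, 40, 'ortho')  # (n_mels, n_mfcc) = (40, 13)
mel_db = torch.rand(1, 81, 40)           # (c, l, n_mels) mel spectrogram in dB
mfcc = torch.matmul(mel_db, dct_mat)     # right-multiply row-wise data -> (1, 81, 13)
```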
@@ -375,14 +386,14 @@ def create_dct(n_mfcc, n_mels, norm):
 @torch.jit.script
 def BLC2CBL(tensor):
     # type: (Tensor) -> Tensor
-    """Permute a 3d tensor from Bands x Sample length x Channels to Channels x
-    Bands x Samples length
+    r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x
+    Bands x Samples length.

-    Inputs:
-        tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
+    Args:
+        tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c)

-    Outputs:
-        Tensor: Tensor of spectrogram with shape (CxBxL)
+    Returns:
+        torch.Tensor: Tensor of spectrogram with shape (c, b, l)
     """
     return tensor.permute(2, 0, 1).contiguous()
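Like `LC2CL`, this is a pure permutation, visible in the one-line body; a minimal sketch:

```python
import torch
import torchaudio.functional as F

spec_blc = torch.randn(4, 100, 2)  # (b, l, c)
spec_cbl = F.BLC2CBL(spec_blc)     # (c, b, l) == (2, 4, 100), returned contiguous
```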
@@ -390,18 +401,18 @@ def BLC2CBL(tensor):
 @torch.jit.script
 def mu_law_encoding(x, qc):
     # type: (Tensor, int) -> Tensor
-    """Encode signal based on mu-law companding. For more info see the
+    r"""Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This algorithm assumes the signal has been scaled to between -1 and 1 and
-    returns a signal encoded with values from 0 to quantization_channels - 1
+    returns a signal encoded with values from 0 to quantization_channels - 1.

-    Inputs:
-        x (Tensor): Input tensor
+    Args:
+        x (torch.Tensor): Input tensor
         qc (int): Number of channels (i.e. quantization channels)

-    Outputs:
-        Tensor: Input after mu-law companding
+    Returns:
+        torch.Tensor: Input after mu-law companding
     """
     assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
     mu = qc - 1.
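An illustrative encoding call, assuming the input is already scaled to [-1, 1] as the docstring requires:

```python
import torch
import torchaudio.functional as F

x = torch.rand(1000) * 2 - 1      # signal scaled to [-1, 1]
x_mu = F.mu_law_encoding(x, 256)  # integer codes in [0, 255]
```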
@@ -417,18 +428,18 @@ def mu_law_encoding(x, qc):
 @torch.jit.script
 def mu_law_expanding(x_mu, qc):
     # type: (Tensor, int) -> Tensor
-    """Decode mu-law encoded signal. For more info see the
+    r"""Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

     This expects an input with values between 0 and quantization_channels - 1
     and returns a signal scaled between -1 and 1.

-    Inputs:
-        x_mu (Tensor): Input tensor
+    Args:
+        x_mu (torch.Tensor): Input tensor
         qc (int): Number of channels (i.e. quantization channels)

-    Outputs:
-        Tensor: Input after decoding
+    Returns:
+        torch.Tensor: Input after decoding
     """
     assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
     mu = qc - 1.
...
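And the matching decode; the cast to float is a defensive assumption, since the decoding arithmetic is floating point:

```python
import torch
import torchaudio.functional as F

x_mu = torch.randint(0, 256, (1000,)).to(torch.float)  # codes in [0, 255]
x = F.mu_law_expanding(x_mu, 256)                      # values back in [-1, 1]
```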