Unverified Commit 4936c9eb authored by Vincent QB, committed by GitHub

Improve Docstrings in transforms (#442)

* get typing on Docstrings right

* Improve Documentation standardise
parent f1a5503e
@@ -5,8 +5,8 @@ from warnings import warn
import math
import torch
from typing import Optional
from . import functional as F
from .compliance import kaldi
from torchaudio import functional as F
from torchaudio.compliance import kaldi
__all__ = [
@@ -28,20 +28,20 @@ __all__ = [
class Spectrogram(torch.nn.Module):
r"""Create a spectrogram from a audio signal
r"""Create a spectrogram from a audio signal.
Args:
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
win_length (int): Window size. (Default: ``n_fft``)
hop_length (int, optional): Length of hop between STFT windows. (
Default: ``win_length // 2``)
pad (int): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
power (float or None, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
If None, then the complex spectrum is returned instead. (Default: ``2``)
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
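A minimal usage sketch for the Spectrogram transform documented above; the file path and parameter values are placeholders, not part of this module:
>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # placeholder file
>>> # n_fft=400 and hop_length=200 are illustrative values
>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200)(waveform)
>>> specgram.shape  # (channel, n_fft // 2 + 1, time)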
@@ -63,7 +63,7 @@ class Spectrogram(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Dimension (..., freq, time), where freq is
@@ -92,22 +92,21 @@ class GriffinLim(torch.nn.Module):
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args:
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
n_iter (int, optional): Number of iteration for phase recovery process.
win_length (int): Window size. (Default: ``n_fft``)
hop_length (int, optional): Length of hop between STFT windows. (
Default: ``win_length // 2``)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
n_iter (int, optional): Number of iterations for the phase recovery process. (Default: ``32``)
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
power (float): Exponent for the magnitude spectrogram,
power (float, optional): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
momentum (float): The momentum parameter for fast Griffin-Lim.
normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
momentum (float, optional): The momentum parameter for fast Griffin-Lim.
Setting this to 0 recovers the original Griffin-Lim method.
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
length (int, optional): Array length of the expected output. (Default: ``None``)
rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
"""
__constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
'length', 'momentum', 'rand_init']
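A minimal sketch of recovering a waveform with GriffinLim, assuming the magnitude spectrogram was produced with the same ``n_fft`` and the default ``power=2``; values are illustrative:
>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # placeholder file
>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=2)(waveform)
>>> restored = torchaudio.transforms.GriffinLim(n_fft=400, n_iter=32, power=2)(specgram)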
@@ -145,7 +144,7 @@ class AmplitudeToDB(torch.nn.Module):
a full clip.
Args:
stype (str): scale of input tensor ('power' or 'magnitude'). The
stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. (Default: ``'power'``)
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80. (Default: ``None``)
@@ -164,14 +163,14 @@ class AmplitudeToDB(torch.nn.Module):
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
def forward(self, x):
r"""Numerically stable implementation from Librosa
r"""Numerically stable implementation from Librosa.
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
Args:
x (torch.Tensor): Input tensor before being converted to decibel scale
x (torch.Tensor): Input tensor before being converted to decibel scale.
Returns:
torch.Tensor: Output tensor in decibel scale
torch.Tensor: Output tensor in decibel scale.
"""
return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
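A minimal sketch converting a power spectrogram to decibels; ``top_db=80`` follows the suggestion in the docstring, and the file path is a placeholder:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file
>>> specgram = torchaudio.transforms.Spectrogram()(waveform)  # power spectrogram by default
>>> specgram_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80)(specgram)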
@@ -183,12 +182,12 @@ class MelScale(torch.nn.Module):
The user can control which device the filter bank (`fb`) is on (e.g. ``fb.to(spec_f.device)``).
Args:
n_mels (int): Number of mel filterbanks. (Default: ``128``)
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
f_min (float): Minimum frequency. (Default: ``0.``)
f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if None is given. See ``n_fft`` in :class:`Spectrogram`.
if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
"""
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
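A minimal sketch mapping a linear-frequency spectrogram onto the mel scale; the values below are illustrative:
>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # placeholder file
>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)
>>> mel_specgram = torchaudio.transforms.MelScale(n_mels=128, sample_rate=sample_rate)(specgram)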
@@ -208,10 +207,10 @@ class MelScale(torch.nn.Module):
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)
specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).
Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
# pack batch
@@ -328,18 +327,17 @@ class MelSpectrogram(torch.nn.Module):
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args:
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
win_length (int): Window size. (Default: ``n_fft``)
hop_length (int, optional): Length of hop between STFT windows. (
Default: ``win_length // 2``)
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
f_min (float): Minimum frequency. (Default: ``0.``)
f_max (float, optional): Maximum frequency. (Default: ``None``)
pad (int): Two sided padding of signal. (Default: ``0``)
n_mels (int): Number of mel filterbanks. (Default: ``128``)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
win_length (int or None, optional): Window size. (Default: ``n_fft``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
f_min (float, optional): Minimum frequency. (Default: ``0.``)
f_max (float or None, optional): Maximum frequency. (Default: ``None``)
pad (int, optional): Two sided padding of signal. (Default: ``0``)
n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
@@ -367,10 +365,10 @@ class MelSpectrogram(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
"""
specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram)
@@ -378,7 +376,7 @@ class MelSpectrogram(torch.nn.Module):
class MFCC(torch.nn.Module):
r"""Create the Mel-frequency cepstrum coefficients from an audio signal
r"""Create the Mel-frequency cepstrum coefficients from an audio signal.
By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
This is not the textbook implementation, but is implemented here to
@@ -389,12 +387,11 @@ class MFCC(torch.nn.Module):
a full clip.
Args:
sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
norm (str, optional): norm to use. (Default: ``'ortho'``)
log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
``False``)
log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
"""
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
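A minimal MFCC usage sketch; the ``melkwargs`` dictionary only illustrates how MelSpectrogram arguments can be forwarded, and the values are placeholders:
>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # placeholder file
>>> mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=40,
...                                   melkwargs={'n_fft': 400, 'n_mels': 128})(waveform)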
@@ -426,10 +423,10 @@ class MFCC(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): Tensor of audio of dimension (..., time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
"""
# pack batch
@@ -460,7 +457,7 @@ class MuLawEncoding(torch.nn.Module):
returns a signal encoded with values from 0 to quantization_channels - 1
Args:
quantization_channels (int): Number of channels (Default: ``256``)
quantization_channels (int, optional): Number of channels. (Default: ``256``)
"""
__constants__ = ['quantization_channels']
@@ -471,10 +468,10 @@ class MuLawEncoding(torch.nn.Module):
def forward(self, x):
r"""
Args:
x (torch.Tensor): A signal to be encoded
x (torch.Tensor): A signal to be encoded.
Returns:
x_mu (torch.Tensor): An encoded signal
x_mu (torch.Tensor): An encoded signal.
"""
return F.mu_law_encoding(x, self.quantization_channels)
@@ -487,7 +484,7 @@ class MuLawDecoding(torch.nn.Module):
and returns a signal scaled between -1 and 1.
Args:
quantization_channels (int): Number of channels (Default: ``256``)
quantization_channels (int, optional): Number of channels. (Default: ``256``)
"""
__constants__ = ['quantization_channels']
@@ -498,23 +495,23 @@ class MuLawDecoding(torch.nn.Module):
def forward(self, x_mu):
r"""
Args:
x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded
x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded.
Returns:
torch.Tensor: The signal decoded
torch.Tensor: The signal decoded.
"""
return F.mu_law_decoding(x_mu, self.quantization_channels)
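A minimal round-trip sketch for the two mu-law transforms, assuming the input waveform is already scaled between -1 and 1; the file path is a placeholder:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file, values in [-1, 1]
>>> encoded = torchaudio.transforms.MuLawEncoding(quantization_channels=256)(waveform)
>>> decoded = torchaudio.transforms.MuLawDecoding(quantization_channels=256)(encoded)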
class Resample(torch.nn.Module):
r"""Resample a signal from one frequency to another. A resampling method can
be given.
r"""Resample a signal from one frequency to another. A resampling method can be given.
Args:
orig_freq (float): The original frequency of the signal. (Default: ``16000``)
new_freq (float): The desired frequency. (Default: ``16000``)
resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
new_freq (float, optional): The desired frequency. (Default: ``16000``)
resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
"""
def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
super(Resample, self).__init__()
self.orig_freq = orig_freq
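A minimal resampling sketch; the frequencies and file path are illustrative:
>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # placeholder file
>>> resampled = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)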
@@ -524,10 +521,10 @@ class Resample(torch.nn.Module):
def forward(self, waveform):
r"""
Args:
waveform (torch.Tensor): The input signal of dimension (..., time)
waveform (torch.Tensor): Tensor of audio of dimension (..., time).
Returns:
torch.Tensor: Output signal of dimension (..., time)
torch.Tensor: Output signal of dimension (..., time).
"""
if self.resampling_method == 'sinc_interpolation':
@@ -546,9 +543,10 @@ class Resample(torch.nn.Module):
class ComplexNorm(torch.nn.Module):
r"""Compute the norm of complex tensor input
r"""Compute the norm of complex tensor input.
Args:
power (float): Power of the norm. Defaults to `1.0`.
power (float, optional): Power of the norm. (Default: ``1.0``)
"""
__constants__ = ['power']
@@ -559,9 +557,10 @@ class ComplexNorm(torch.nn.Module):
def forward(self, complex_tensor):
r"""
Args:
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.
Returns:
Tensor: norm of the input tensor, shape of `(..., )`
Tensor: norm of the input tensor, shape of `(..., )`.
"""
return F.complex_norm(complex_tensor, self.power)
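A minimal sketch; the complex spectrogram here comes from Spectrogram with ``power=None``, so the last dimension holds the real and imaginary parts, and the file path is a placeholder:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file
>>> complex_specgram = torchaudio.transforms.Spectrogram(power=None)(waveform)  # (..., freq, time, 2)
>>> magnitude = torchaudio.transforms.ComplexNorm(power=1.0)(complex_specgram)  # (..., freq, time)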
@@ -572,7 +571,8 @@ class ComputeDeltas(torch.nn.Module):
See `torchaudio.functional.compute_deltas` for more details.
Args:
win_length (int): The window length used for computing delta.
win_length (int, optional): The window length used for computing delta. (Default: ``5``)
mode (str, optional): Mode parameter passed to padding. (Default: ``'replicate'``)
"""
__constants__ = ['win_length']
@@ -584,10 +584,10 @@ class ComputeDeltas(torch.nn.Module):
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time).
Returns:
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time).
"""
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
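A minimal sketch computing delta coefficients from a spectrogram; the inputs are placeholders:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file
>>> specgram = torchaudio.transforms.MelSpectrogram()(waveform)
>>> deltas = torchaudio.transforms.ComputeDeltas(win_length=5)(specgram)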
@@ -596,9 +596,9 @@ class TimeStretch(torch.nn.Module):
r"""Stretch stft in time without modifying pitch for a given rate.
Args:
hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``)
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``n_fft // 2``)
n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
fixed_rate (float): rate to speed up or slow down by.
fixed_rate (float or None, optional): rate to speed up or slow down by.
If None is provided, rate must be passed to the forward method. (Default: ``None``)
"""
__constants__ = ['fixed_rate']
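A minimal sketch; the complex spectrogram again comes from Spectrogram with ``power=None``, and the rate of 1.2 and other values are illustrative:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file
>>> complex_specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=None)(waveform)
>>> stretched = torchaudio.transforms.TimeStretch(hop_length=200, n_freq=201, fixed_rate=1.2)(complex_specgram)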
@@ -616,12 +616,12 @@ class TimeStretch(torch.nn.Module):
# type: (Tensor, Optional[float]) -> Tensor
r"""
Args:
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
overriding_rate (float or None): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
Returns:
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
@@ -643,9 +643,9 @@ class _AxisMasking(torch.nn.Module):
r"""Apply masking to a spectrogram.
Args:
mask_param (int): Maximum possible length of the mask
axis: What dimension the mask is applied on
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
mask_param (int): Maximum possible length of the mask.
axis (int): What dimension the mask is applied on.
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']
@@ -660,10 +660,11 @@ class _AxisMasking(torch.nn.Module):
# type: (Tensor, float) -> Tensor
r"""
Args:
specgram (torch.Tensor): Tensor of dimension (..., freq, time)
specgram (torch.Tensor): Tensor of dimension (..., freq, time).
mask_value (float): Value to assign to the masked columns.
Returns:
torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
torch.Tensor: Masked spectrogram of dimensions (..., freq, time).
"""
# if iid_masks flag marked and specgram has a batch dimension
@@ -679,8 +680,8 @@ class FrequencyMasking(_AxisMasking):
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
iid_masks (bool): weather to apply the same mask to all
the examples/channels in the batch. (Default: False)
iid_masks (bool, optional): whether to apply the same mask to all
the examples/channels in the batch. (Default: ``False``)
"""
def __init__(self, freq_mask_param, iid_masks=False):
@@ -693,8 +694,8 @@ class TimeMasking(_AxisMasking):
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
iid_masks (bool): weather to apply the same mask to all
the examples/channels in the batch. Defaults to False.
iid_masks (bool, optional): whether to apply the same mask to all
the examples/channels in the batch. (Default: ``False``)
"""
def __init__(self, time_mask_param, iid_masks=False):
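A minimal SpecAugment-style sketch applying both masking transforms to a spectrogram; the mask lengths and file path are illustrative:
>>> import torchaudio
>>> waveform, _ = torchaudio.load('test.wav')  # placeholder file
>>> specgram = torchaudio.transforms.MelSpectrogram()(waveform)
>>> specgram = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)(specgram)
>>> specgram = torchaudio.transforms.TimeMasking(time_mask_param=40)(specgram)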