Commit 8f270d09 authored by Caroline Chen

Standardize tensor shapes format in docs (#1838)

parent dc0990c7
@@ -650,20 +650,20 @@ def filtfilt(
Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
Returns:
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or `(..., time)` otherwise.
"""
forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
backward_filtered = lfilter(
@@ -970,13 +970,13 @@ def lfilter(
Using double precision could also minimize numerical precision errors.
Args:
waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
Must be same size as b_coeffs (pad with 0's as necessary).
b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
Must be same size as a_coeffs (pad with 0's as necessary).
clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
@@ -986,8 +986,8 @@ def lfilter(
a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
Returns:
Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
are 2D Tensors, or `(..., time)` otherwise.
"""
assert a_coeffs.size() == b_coeffs.size()
assert a_coeffs.ndim <= 2
...
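The coefficient convention documented above (lowest-delay coefficients first, with ``a0`` normalizing the output) can be illustrated with a naive direct-form IIR filter in plain PyTorch. `naive_lfilter` is a hypothetical helper written for illustration only, not torchaudio's actual (much faster) implementation:

```python
import torch

def naive_lfilter(waveform: torch.Tensor, a_coeffs: torch.Tensor,
                  b_coeffs: torch.Tensor) -> torch.Tensor:
    # y[n] = (sum_k b[k]*x[n-k] - sum_{k>=1} a[k]*y[n-k]) / a[0]
    # Coefficients are ordered lowest delay first: [a0, a1, ...], [b0, b1, ...]
    assert a_coeffs.size() == b_coeffs.size()  # pad with 0's as necessary
    order = a_coeffs.numel()
    x = waveform.tolist()
    y = []
    for n in range(len(x)):
        acc = 0.0
        for k in range(order):
            if n - k >= 0:
                acc += b_coeffs[k].item() * x[n - k]
        for k in range(1, order):
            if n - k >= 0:
                acc -= a_coeffs[k].item() * y[n - k]
        y.append(acc / a_coeffs[0].item())
    return torch.tensor(y)

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Identity filter: b = [1], a = [1], both zero-padded to the same size
identity = naive_lfilter(x, torch.tensor([1.0, 0.0]), torch.tensor([1.0, 0.0]))
# Two-tap moving average: b = [0.5, 0.5], a = [1, 0]
smoothed = naive_lfilter(x, torch.tensor([1.0, 0.0]), torch.tensor([0.5, 0.5]))
```

The zero-padding of the shorter coefficient vector mirrors the "pad with 0's as necessary" note in the docstring.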
@@ -62,7 +62,7 @@ def spectrogram(
The spectrogram can be either magnitude-only or complex.
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
pad (int): Two-sided padding of signal
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT
@@ -89,7 +89,7 @@ def spectrogram(
power spectrogram, which is a real-valued tensor.
Returns:
Tensor: Dimension `(..., freq, time)`, where freq is
``n_fft // 2 + 1``, the number of Fourier bins,
and time is the number of window hops (n_frame).
"""
@@ -172,7 +172,7 @@ def inverse_spectrogram(
Default: ``True``
Returns:
Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
"""
if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
@@ -246,7 +246,7 @@ def griffinlim(
and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].
Args:
specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
where freq is ``n_fft // 2 + 1``.
window (Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
@@ -263,7 +263,7 @@ def griffinlim(
rand_init (bool): Initializes phase randomly if True, to zero otherwise.
Returns:
torch.Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
"""
assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
assert momentum >= 0, 'momentum={} < 0'.format(momentum)
@@ -791,10 +791,10 @@ def phase_vocoder(
Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
Returns:
Tensor:
@@ -907,13 +907,13 @@ def mask_along_axis_iid(
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
Args:
specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
Returns:
Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
"""
if axis not in [2, 3]:
@@ -950,13 +950,13 @@ def mask_along_axis(
All examples will have the same mask interval.
Args:
specgram (Tensor): Real spectrogram `(channel, freq, time)`
mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
mask_value (float): Value to assign to the masked columns
axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
Returns:
Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
"""
if axis not in [1, 2]:
raise ValueError('Only Frequency and Time masking are supported')
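The masking scheme described above (width ``v`` drawn from ``uniform(0, mask_param)``, start ``v_0`` from ``uniform(0, max_v - v)``) can be sketched in pure torch. `mask_along_axis_sketch` is a simplified stand-in for illustration (it assumes ``mask_param`` does not exceed the axis size), not torchaudio's implementation:

```python
import torch

def mask_along_axis_sketch(specgram: torch.Tensor, mask_param: int,
                           mask_value: float, axis: int) -> torch.Tensor:
    # Mask a random band of columns along `axis` of a (channel, freq, time) spectrogram.
    if axis not in [1, 2]:
        raise ValueError('Only Frequency and Time masking are supported')
    max_v = specgram.size(axis)
    v = int(torch.rand(1).item() * mask_param)        # mask width ~ uniform(0, mask_param)
    v_0 = int(torch.rand(1).item() * (max_v - v))     # mask start ~ uniform(0, max_v - v)
    out = specgram.transpose(axis, -1).clone()        # move masked axis last
    out[..., v_0:v_0 + v] = mask_value                # fill the band
    return out.transpose(axis, -1)                    # restore original layout

torch.manual_seed(0)
spec = torch.ones(1, 4, 8)                            # (channel, freq, time)
masked = mask_along_axis_sketch(spec, mask_param=3, mask_value=0.0, axis=2)
```

All examples share the same mask interval here, matching the non-iid variant; the iid variant draws a separate `(v, v_0)` per batch/channel element.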
@@ -999,12 +999,12 @@ def compute_deltas(
:math:`N` is ``(win_length-1)//2``.
Args:
specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
win_length (int, optional): The window length used for computing delta (Default: ``5``)
mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
Returns:
Tensor: Tensor of deltas of dimension `(..., freq, time)`
Example
>>> specgram = torch.randn(1, 40, 1000)
@@ -1172,7 +1172,7 @@ def detect_pitch_frequency(
It is implemented using normalized cross-correlation function and median smoothing.
Args:
waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
sample_rate (int): The sample rate of the waveform (Hz)
frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
@@ -1180,7 +1180,7 @@ def detect_pitch_frequency(
freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
Returns:
Tensor: Tensor of freq of dimension `(..., frame)`
"""
# pack batch
shape = list(waveform.size())
@@ -1211,7 +1211,7 @@ def sliding_window_cmn(
Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
Args:
specgram (Tensor): Tensor of audio of dimension `(..., time, freq)`
cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
Only applicable if center == false, ignored if center==true (int, default = 100)
@@ -1220,7 +1220,7 @@ def sliding_window_cmn(
norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
Returns:
Tensor: Tensor matching input shape `(..., freq, time)`
"""
input_shape = specgram.shape
num_frames, num_feats = input_shape[-2:]
@@ -1307,7 +1307,7 @@ def spectral_centroid(
frequency values, weighted by their magnitude.
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
sample_rate (int): Sample rate of the audio waveform
pad (int): Two-sided padding of signal
window (Tensor): Window tensor that is applied/multiplied to each frame/window
@@ -1316,7 +1316,7 @@ def spectral_centroid(
win_length (int): Window size
Returns:
Tensor: Dimension `(..., time)`
"""
specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, power=1., normalized=False)
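The spectral centroid is just a magnitude-weighted mean of the bin frequencies. A minimal pure-torch sketch of that formula (assuming a magnitude spectrogram of shape `(..., freq, time)` with ``freq = n_fft // 2 + 1``; `centroid_sketch` is an illustrative name, not the torchaudio function):

```python
import torch

def centroid_sketch(specgram: torch.Tensor, sample_rate: int, n_fft: int) -> torch.Tensor:
    # specgram: magnitude spectrogram, (..., freq, time)
    freqs = torch.linspace(0, sample_rate // 2, steps=n_fft // 2 + 1)  # Hz per bin
    freqs = freqs.reshape(-1, 1)                                       # (freq, 1)
    # Weighted mean over the freq dimension -> (..., time)
    return (freqs * specgram).sum(dim=-2) / specgram.sum(dim=-2)

# A spectrogram with all energy in one bin has its centroid at that bin's frequency.
spec = torch.zeros(5, 3)      # freq = 5 (n_fft = 8), time = 3
spec[2, :] = 1.0              # all energy in bin 2 -> 2000 Hz at sample_rate 8000
centroid = centroid_sketch(spec, sample_rate=8000, n_fft=8)
```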
@@ -1344,8 +1344,8 @@ def apply_codec(
sample_rate (int): Sample rate of the audio waveform.
format (str): File format.
channels_first (bool, optional):
When True, both the input and output Tensor have dimension `(channel, time)`.
Otherwise, they have dimension `(time, channel)`.
compression (float or None, optional): Used for formats other than WAV.
For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
encoding (str or None, optional): Changes the encoding for the supported formats.
@@ -1355,7 +1355,7 @@ def apply_codec(
Returns:
torch.Tensor: Resulting Tensor.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
"""
bytes = io.BytesIO()
torchaudio.backend.sox_io_backend.save(bytes,
@@ -1453,7 +1453,7 @@ def compute_kaldi_pitch(
This makes different types of features give the same number of frames. (default: True)
Returns:
Tensor: Pitch feature. Shape: `(batch, frames, 2)` where the last dimension
corresponds to pitch and NCCF.
"""
shape = waveform.shape
@@ -1605,7 +1605,7 @@ def resample(
more efficient computation if resampling multiple waveforms with the same resampling parameters.
Args:
waveform (Tensor): The input signal of dimension `(..., time)`
orig_freq (float): The original frequency of the signal
new_freq (float): The desired frequency
lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
@@ -1617,7 +1617,7 @@ def resample(
beta (float or None, optional): The shape parameter used for kaiser window.
Returns:
Tensor: The waveform at the new frequency of dimension `(..., time)`.
"""
assert orig_freq > 0.0 and new_freq > 0.0
...
@@ -301,9 +301,9 @@ class FeedForward(Module):
def forward(self, x):
"""
Args:
x (Tensor): shape: `(batch, sequence_length, io_features)`
Returns:
x (Tensor): shape: `(batch, sequence_length, io_features)`
"""
x = self.intermediate_dense(x)
x = torch.nn.functional.gelu(x)
@@ -339,9 +339,9 @@ class EncoderLayer(Module):
):
"""
Args:
x (Tensor): shape: `(batch, sequence_length, embed_dim)`
attention_mask (Tensor or None, optional):
shape: `(batch, 1, sequence_length, sequence_length)`
"""
residual = x
...
@@ -48,10 +48,10 @@ class Wav2Vec2Model(Module):
transformer block in encoder.
Args:
waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
lengths (Tensor or None, optional):
Indicates the valid length of each audio sample in the batch.
Shape: `(batch, )`.
num_layers (int or None, optional):
If given, limit the number of intermediate layers to go through.
Providing `1` will stop the computation after going through one
@@ -62,9 +62,9 @@ class Wav2Vec2Model(Module):
List of Tensors and an optional Tensor:
List of Tensors
Features from requested layers.
Each Tensor is of shape: `(batch, frames, feature dimension)`
Tensor or None
If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
is returned. It indicates the valid length of each feature in the batch.
"""
x, lengths = self.feature_extractor(waveforms, lengths)
@@ -79,18 +79,18 @@ class Wav2Vec2Model(Module):
"""Compute the sequence of probability distribution over labels.
Args:
waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
lengths (Tensor or None, optional):
Indicates the valid length of each audio sample in the batch.
Shape: `(batch, )`.
Returns:
Tensor and an optional Tensor:
Tensor
The sequences of probability distribution (in logit) over labels.
Shape: `(batch, frames, num labels)`.
Tensor or None
If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
is returned. It indicates the valid length of each feature in the batch.
"""
x, lengths = self.feature_extractor(waveforms, lengths)
...
@@ -971,8 +971,8 @@ class TimeStretch(torch.nn.Module):
r"""
Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
@@ -1018,10 +1018,10 @@ class Fade(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`.
Returns:
Tensor: Tensor of audio of dimension `(..., time)`.
"""
waveform_length = waveform.size()[-1]
device = waveform.device
@@ -1092,11 +1092,11 @@ class _AxisMasking(torch.nn.Module):
def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
r"""
Args:
specgram (Tensor): Tensor of dimension `(..., freq, time)`.
mask_value (float): Value to assign to the masked columns.
Returns:
Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
"""
# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
@@ -1157,10 +1157,10 @@ class Vol(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`.
Returns:
Tensor: Tensor of audio of dimension `(..., time)`.
"""
if self.gain_type == "amplitude":
waveform = waveform * self.gain
@@ -1201,10 +1201,10 @@ class SlidingWindowCmn(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`.
Returns:
Tensor: Tensor of audio of dimension `(..., time)`.
"""
cmn_waveform = F.sliding_window_cmn(
waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
@@ -1374,10 +1374,10 @@ class SpectralCentroid(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`.
Returns:
Tensor: Spectral Centroid of size `(..., time)`.
"""
return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
@@ -1428,7 +1428,7 @@ class PitchShift(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`.
Returns:
Tensor: The pitch-shifted audio of shape `(..., time)`.
@@ -1513,7 +1513,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
@@ -1548,14 +1548,14 @@ class PSD(torch.nn.Module):
"""
Args:
specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
Tensor of dimension `(..., channel, freq, time)`
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
of dimension `(..., channel, freq, time)` if multi_mask is ``True``
Returns:
torch.Tensor: PSD matrix of the input STFT matrix.
Tensor of dimension `(..., freq, channel, channel)`
"""
# outer product:
# (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
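The outer product described in the comment above can be sketched with ``torch.einsum`` over a multi-channel complex STFT. This is a shapes-only illustration, not the exact PSD implementation; averaging over time frames is an assumption made here for the sketch:

```python
import torch

channel, freq, time = 4, 201, 50
# Multi-channel complex-valued STFT, (channel, freq, time)
specgram = torch.randn(channel, freq, time, dtype=torch.cfloat)

# Per-frequency channel outer product, averaged over time:
# (ch_1, freq, time) x conj(ch_2, freq, time) -> (freq, ch_1, ch_2)
psd = torch.einsum("cft,dft->fcd", specgram, specgram.conj()) / time
```

The result has the documented `(..., freq, channel, channel)` layout and is Hermitian in its last two dimensions, as a covariance matrix should be.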
@@ -1804,11 +1804,11 @@ class MVDR(torch.nn.Module):
Args:
psd_s (torch.Tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
_, indices = torch.max(w.abs(), dim=-1, keepdim=True)
@@ -1826,14 +1826,14 @@ class MVDR(torch.nn.Module):
Args:
psd_s (torch.Tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
psd_n (torch.Tensor): covariance matrix of noise
Tensor of dimension `(..., freq, channel, channel)`
reference_vector (torch.Tensor): one-hot reference channel matrix
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
@@ -1850,13 +1850,13 @@ class MVDR(torch.nn.Module):
r"""Apply the beamforming weight to the noisy STFT
Args:
specgram (torch.Tensor): multi-channel noisy STFT
Tensor of dimension `(..., channel, freq, time)`
beamform_vector (torch.Tensor): beamforming weight matrix
Tensor of dimension `(..., freq, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, time)`
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
@@ -1897,18 +1897,18 @@ class MVDR(torch.nn.Module):
Args:
specgram (torch.Tensor): the multi-channel STFT of the noisy speech.
Tensor of dimension `(..., channel, freq, time)`
mask_s (torch.Tensor): Time-Frequency mask of target speech.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or of dimension `(..., channel, freq, time)` if multi_mask is ``True``
mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or of dimension `(..., channel, freq, time)` if multi_mask is ``True``
(Default: None)
Returns:
torch.Tensor: The single-channel STFT of the enhanced speech.
Tensor of dimension `(..., freq, time)`
"""
if specgram.ndim < 3:
raise ValueError(
...