Unverified Commit 4936c9eb authored by Vincent QB, committed by GitHub

Improve Docstrings in transforms (#442)

* get typing on Docstrings right

* Improve and standardise documentation
parent f1a5503e
@@ -5,8 +5,8 @@ from warnings import warn
 import math
 import torch
 from typing import Optional
-from . import functional as F
-from .compliance import kaldi
+from torchaudio import functional as F
+from torchaudio.compliance import kaldi

 __all__ = [
@@ -28,20 +28,20 @@ __all__ = [
 class Spectrogram(torch.nn.Module):
-    r"""Create a spectrogram from a audio signal
+    r"""Create a spectrogram from a audio signal.

     Args:
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        pad (int): Two sided padding of signal. (Default: ``0``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        pad (int, optional): Two sided padding of signal. (Default: ``0``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        power (float): Exponent for the magnitude spectrogram,
-            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
-        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
+        power (float or None, optional): Exponent for the magnitude spectrogram,
+            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+            If None, then the complex spectrum is returned instead. (Default: ``2``)
+        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
     """
     __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
@@ -63,7 +63,7 @@ class Spectrogram(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
             torch.Tensor: Dimension (..., freq, time), where freq is
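An illustrative usage sketch for the transform documented above (not taken from the patch; it assumes ``torch`` and ``torchaudio`` are installed and uses a random tensor as a stand-in for real audio):

    >>> import torch
    >>> import torchaudio
    >>> transform = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200, power=2)
    >>> waveform = torch.randn(1, 16000)  # hypothetical 1-second mono clip at 16 kHz
    >>> specgram = transform(waveform)    # (channel, freq, time) == (1, 201, 81)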
@@ -92,22 +92,21 @@ class GriffinLim(torch.nn.Module):
     IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

     Args:
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        n_iter (int, optional): Number of iteration for phase recovery process.
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        power (float): Exponent for the magnitude spectrogram,
+        power (float, optional): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
-        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
-        momentum (float): The momentum parameter for fast Griffin-Lim.
+        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
+        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
             Setting this to 0 recovers the original Griffin-Lim method.
-            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
+            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
         length (int, optional): Array length of the expected output. (Default: ``None``)
-        rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
+        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
     """
     __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                      'length', 'momentum', 'rand_init']
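A sketch of how GriffinLim pairs with Spectrogram to recover a waveform from a magnitude spectrogram (illustrative only; the STFT parameters must match on both sides):

    >>> import torch
    >>> import torchaudio
    >>> spec = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200, power=2)
    >>> griffin_lim = torchaudio.transforms.GriffinLim(n_fft=400, hop_length=200, power=2, n_iter=32)
    >>> waveform = torch.randn(1, 16000)             # stand-in for real audio
    >>> reconstructed = griffin_lim(spec(waveform))  # approximate waveform of shape (1, time)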
@@ -145,7 +144,7 @@ class AmplitudeToDB(torch.nn.Module):
     a full clip.

     Args:
-        stype (str): scale of input tensor ('power' or 'magnitude'). The
+        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
             power being the elementwise square of the magnitude. (Default: ``'power'``)
         top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
             is 80. (Default: ``None``)
@@ -164,14 +163,14 @@ class AmplitudeToDB(torch.nn.Module):
         self.db_multiplier = math.log10(max(self.amin, self.ref_value))

     def forward(self, x):
-        r"""Numerically stable implementation from Librosa
+        r"""Numerically stable implementation from Librosa.

         https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

         Args:
-            x (torch.Tensor): Input tensor before being converted to decibel scale
+            x (torch.Tensor): Input tensor before being converted to decibel scale.

         Returns:
-            torch.Tensor: Output tensor in decibel scale
+            torch.Tensor: Output tensor in decibel scale.
         """
         return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
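An illustrative sketch of converting a power spectrogram to decibels with this transform (assumes torch/torchaudio are available; the data is a random stand-in):

    >>> import torch
    >>> import torchaudio
    >>> specgram = torchaudio.transforms.Spectrogram(power=2)(torch.randn(1, 16000))
    >>> specgram_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80)(specgram)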
@@ -183,12 +182,12 @@ class MelScale(torch.nn.Module):
     User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

     Args:
-        n_mels (int): Number of mel filterbanks. (Default: ``128``)
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        f_min (float): Minimum frequency. (Default: ``0.``)
-        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        f_min (float, optional): Minimum frequency. (Default: ``0.``)
+        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
         n_stft (int, optional): Number of bins in STFT. Calculated from first input
-            if None is given. See ``n_fft`` in :class:`Spectrogram`.
+            if None is given. See ``n_fft`` in :class:`Spectrogram`. (Default: ``None``)
     """
     __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
@@ -208,10 +207,10 @@ class MelScale(torch.nn.Module):
     def forward(self, specgram):
         r"""
         Args:
-            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)
+            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).

         Returns:
-            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
+            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
         """
         # pack batch
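A sketch of chaining Spectrogram and MelScale as described above (illustrative only; the filter bank is inferred from the first input when ``n_stft`` is ``None``):

    >>> import torch
    >>> import torchaudio
    >>> specgram = torchaudio.transforms.Spectrogram(n_fft=400)(torch.randn(1, 16000))          # (1, 201, 81)
    >>> mel_specgram = torchaudio.transforms.MelScale(n_mels=128, sample_rate=16000)(specgram)  # (1, 128, 81)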
@@ -328,18 +327,17 @@ class MelSpectrogram(torch.nn.Module):
     * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

     Args:
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        f_min (float): Minimum frequency. (Default: ``0.``)
-        f_max (float, optional): Maximum frequency. (Default: ``None``)
-        pad (int): Two sided padding of signal. (Default: ``0``)
-        n_mels (int): Number of mel filterbanks. (Default: ``128``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        f_min (float, optional): Minimum frequency. (Default: ``0.``)
+        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+        pad (int, optional): Two sided padding of signal. (Default: ``0``)
+        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)

     Example
         >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
@@ -367,10 +365,10 @@ class MelSpectrogram(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
+            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
         """
         specgram = self.spectrogram(waveform)
         mel_specgram = self.mel_scale(specgram)
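For reference, a minimal sketch of the fused transform (a random tensor stands in for a loaded file):

    >>> import torch
    >>> import torchaudio
    >>> transform = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=400, n_mels=128)
    >>> mel_specgram = transform(torch.randn(1, 16000))  # (channel, n_mels, time) == (1, 128, 81)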
@@ -378,7 +376,7 @@
 class MFCC(torch.nn.Module):
-    r"""Create the Mel-frequency cepstrum coefficients from an audio signal
+    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

     By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
     This is not the textbook implementation, but is implemented here to
@@ -389,12 +387,11 @@ class MFCC(torch.nn.Module):
     a full clip.

     Args:
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
-        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
+        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
         norm (str, optional): norm to use. (Default: ``'ortho'``)
-        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
-            ``False``)
+        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
         melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
     """
     __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
@@ -426,10 +423,10 @@ class MFCC(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
+            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
         """
         # pack batch
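An illustrative sketch of the MFCC transform with ``melkwargs`` forwarded to the underlying MelSpectrogram (the parameter values here are arbitrary examples):

    >>> import torch
    >>> import torchaudio
    >>> transform = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=40,
    ...                                        melkwargs={'n_fft': 400, 'hop_length': 200, 'n_mels': 128})
    >>> mfcc = transform(torch.randn(1, 16000))  # (channel, n_mfcc, time) == (1, 40, 81)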
@@ -460,7 +457,7 @@ class MuLawEncoding(torch.nn.Module):
     returns a signal encoded with values from 0 to quantization_channels - 1

     Args:
-        quantization_channels (int): Number of channels (Default: ``256``)
+        quantization_channels (int, optional): Number of channels. (Default: ``256``)
     """
     __constants__ = ['quantization_channels']
@@ -471,10 +468,10 @@ class MuLawEncoding(torch.nn.Module):
     def forward(self, x):
         r"""
         Args:
-            x (torch.Tensor): A signal to be encoded
+            x (torch.Tensor): A signal to be encoded.

         Returns:
-            x_mu (torch.Tensor): An encoded signal
+            x_mu (torch.Tensor): An encoded signal.
         """
         return F.mu_law_encoding(x, self.quantization_channels)
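A sketch of encoding a waveform scaled to [-1, 1] (illustrative; random data stands in for audio):

    >>> import torch
    >>> import torchaudio
    >>> waveform = torch.rand(1, 16000) * 2 - 1  # values in [-1, 1]
    >>> encoded = torchaudio.transforms.MuLawEncoding(quantization_channels=256)(waveform)  # integer codes in [0, 255]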
@@ -487,7 +484,7 @@ class MuLawDecoding(torch.nn.Module):
     and returns a signal scaled between -1 and 1.

     Args:
-        quantization_channels (int): Number of channels (Default: ``256``)
+        quantization_channels (int, optional): Number of channels. (Default: ``256``)
     """
     __constants__ = ['quantization_channels']
@@ -498,23 +495,23 @@ class MuLawDecoding(torch.nn.Module):
     def forward(self, x_mu):
         r"""
         Args:
-            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded
+            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded.

         Returns:
-            torch.Tensor: The signal decoded
+            torch.Tensor: The signal decoded.
         """
         return F.mu_law_decoding(x_mu, self.quantization_channels)
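And a round-trip sketch pairing the two mu-law transforms (illustrative only):

    >>> import torch
    >>> import torchaudio
    >>> waveform = torch.rand(1, 16000) * 2 - 1
    >>> encoded = torchaudio.transforms.MuLawEncoding(256)(waveform)
    >>> decoded = torchaudio.transforms.MuLawDecoding(256)(encoded)  # approximately the original, back in [-1, 1]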
 class Resample(torch.nn.Module):
-    r"""Resample a signal from one frequency to another. A resampling method can
-    be given.
+    r"""Resample a signal from one frequency to another. A resampling method can be given.

     Args:
-        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
-        new_freq (float): The desired frequency. (Default: ``16000``)
-        resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
+        orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
+        new_freq (float, optional): The desired frequency. (Default: ``16000``)
+        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
     """

     def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
         super(Resample, self).__init__()
         self.orig_freq = orig_freq
@@ -524,10 +521,10 @@ class Resample(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): The input signal of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: Output signal of dimension (..., time)
+            torch.Tensor: Output signal of dimension (..., time).
         """
         if self.resampling_method == 'sinc_interpolation':
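An illustrative downsampling sketch (it assumes a (channel, time) waveform; random data is used here):

    >>> import torch
    >>> import torchaudio
    >>> resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=8000)
    >>> downsampled = resample(torch.randn(1, 16000))  # roughly (1, 8000)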
@@ -546,9 +543,10 @@ class Resample(torch.nn.Module):
 class ComplexNorm(torch.nn.Module):
-    r"""Compute the norm of complex tensor input
+    r"""Compute the norm of complex tensor input.
+
     Args:
-        power (float): Power of the norm. Defaults to `1.0`.
+        power (float, optional): Power of the norm. (Default: to ``1.0``)
     """
     __constants__ = ['power']
@@ -559,9 +557,10 @@ class ComplexNorm(torch.nn.Module):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.
+
         Returns:
-            Tensor: norm of the input tensor, shape of `(..., )`
+            Tensor: norm of the input tensor, shape of `(..., )`.
         """
         return F.complex_norm(complex_tensor, self.power)
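A sketch of taking the magnitude of a complex spectrogram, using ``power=None`` in Spectrogram to obtain the ``(..., complex=2)`` layout (illustrative only):

    >>> import torch
    >>> import torchaudio
    >>> complex_specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=None)(torch.randn(1, 16000))  # (1, 201, 81, 2)
    >>> magnitude = torchaudio.transforms.ComplexNorm(power=1.)(complex_specgram)                           # (1, 201, 81)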
@@ -572,7 +571,8 @@ class ComputeDeltas(torch.nn.Module):
     See `torchaudio.functional.compute_deltas` for more details.

     Args:
-        win_length (int): The window length used for computing delta.
+        win_length (int): The window length used for computing delta. (Default: ``5``)
+        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
     """
     __constants__ = ['win_length']
@@ -584,10 +584,10 @@ class ComputeDeltas(torch.nn.Module):
     def forward(self, specgram):
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
+            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time).

         Returns:
-            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
+            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time).
         """
         return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
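An illustrative sketch of computing delta features from MFCCs (random stand-in data):

    >>> import torch
    >>> import torchaudio
    >>> mfcc = torchaudio.transforms.MFCC(sample_rate=16000)(torch.randn(1, 16000))
    >>> deltas = torchaudio.transforms.ComputeDeltas(win_length=5)(mfcc)  # same shape as the input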
@@ -596,9 +596,9 @@ class TimeStretch(torch.nn.Module):
     r"""Stretch stft in time without modifying pitch for a given rate.

     Args:
-        hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
         n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
-        fixed_rate (float): rate to speed up or slow down by.
+        fixed_rate (float or None, optional): rate to speed up or slow down by.
             If None is provided, rate must be passed to the forward method. (Default: ``None``)
     """
     __constants__ = ['fixed_rate']
@@ -616,12 +616,12 @@ class TimeStretch(torch.nn.Module):
         # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
-            overriding_rate (float or None): speed up to apply to this batch.
-                If no rate is passed, use ``self.fixed_rate``
+            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
+            overriding_rate (float or None, optional): speed up to apply to this batch.
+                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
         """
         assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
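A sketch of stretching a complex spectrogram by a fixed rate (illustrative; ``n_freq`` must match ``n_fft // 2 + 1`` of the Spectrogram, and ``power=None`` keeps the complex layout):

    >>> import torch
    >>> import torchaudio
    >>> complex_specgram = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200, power=None)(torch.randn(1, 16000))
    >>> stretch = torchaudio.transforms.TimeStretch(hop_length=200, n_freq=201, fixed_rate=1.3)
    >>> stretched = stretch(complex_specgram)  # (..., freq, ceil(time / 1.3), complex=2)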
@@ -643,9 +643,9 @@ class _AxisMasking(torch.nn.Module):
     r"""Apply masking to a spectrogram.

     Args:
-        mask_param (int): Maximum possible length of the mask
-        axis: What dimension the mask is applied on
-        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
+        mask_param (int): Maximum possible length of the mask.
+        axis (int): What dimension the mask is applied on.
+        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
     """
     __constants__ = ['mask_param', 'axis', 'iid_masks']
@@ -660,10 +660,11 @@ class _AxisMasking(torch.nn.Module):
         # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
+            specgram (torch.Tensor): Tensor of dimension (..., freq, time).
+            mask_value (float): Value to assign to the masked columns.

         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (..., freq, time).
         """
         # if iid_masks flag marked and specgram has a batch dimension
@@ -679,8 +680,8 @@ class FrequencyMasking(_AxisMasking):
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
-        iid_masks (bool): weather to apply the same mask to all
-            the examples/channels in the batch. (Default: False)
+        iid_masks (bool, optional): weather to apply the same mask to all
+            the examples/channels in the batch. (Default: ``False``)
     """

     def __init__(self, freq_mask_param, iid_masks=False):
@@ -693,8 +694,8 @@ class TimeMasking(_AxisMasking):
     Args:
         time_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, time_mask_param).
-        iid_masks (bool): weather to apply the same mask to all
-            the examples/channels in the batch. Defaults to False.
+        iid_masks (bool, optional): weather to apply the same mask to all
+            the examples/channels in the batch. (Default: ``False``)
     """

     def __init__(self, time_mask_param, iid_masks=False):
......
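Finally, a SpecAugment-style sketch combining the two masking transforms documented above (illustrative; ``mask_value`` defaults to ``0.0``):

    >>> import torch
    >>> import torchaudio
    >>> mel_specgram = torchaudio.transforms.MelSpectrogram(sample_rate=16000)(torch.randn(1, 16000))
    >>> masked = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)(mel_specgram)
    >>> masked = torchaudio.transforms.TimeMasking(time_mask_param=40)(masked)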