OpenDAS / Torchaudio / Commits

Commit 4936c9eb (unverified)
Improve Docstrings in transforms (#442)

* get typing on Docstrings right
* Improve Documentation, standardise

Authored Mar 05, 2020 by Vincent QB; committed via GitHub on Mar 05, 2020
Parent: f1a5503e
Showing 1 changed file with 94 additions and 93 deletions

torchaudio/transforms.py  +94 −93
@@ -5,8 +5,8 @@ from warnings import warn
 import math
 import torch
 from typing import Optional
-from . import functional as F
-from .compliance import kaldi
+from torchaudio import functional as F
+from torchaudio.compliance import kaldi

 __all__ = [
@@ -28,20 +28,20 @@ __all__ = [
 class Spectrogram(torch.nn.Module):
-    r"""Create a spectrogram from a audio signal
+    r"""Create a spectrogram from a audio signal.

     Args:
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        pad (int): Two sided padding of signal. (Default: ``0``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        pad (int, optional): Two sided padding of signal. (Default: ``0``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        power (float): Exponent for the magnitude spectrogram,
-            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
-        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
+        power (float or None, optional): Exponent for the magnitude spectrogram,
+            (must be > 0) e.g., 1 for energy, 2 for power, etc.
+            If None, then the complex spectrum is returned instead. (Default: ``2``)
+        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
     """
     __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
@@ -63,7 +63,7 @@ class Spectrogram(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
             torch.Tensor: Dimension (..., freq, time), where freq is
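A minimal usage sketch of the Spectrogram transform documented above (editor's illustration, not part of the commit; file name and parameter values are examples only):

>>> import torchaudio
>>> waveform, sample_rate = torchaudio.load('test.wav')  # waveform shape: (channel, time)
>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200)(waveform)
>>> specgram.shape  # (channel, n_fft // 2 + 1, frames) = (channel, 201, frames)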
@@ -92,22 +92,21 @@ class GriffinLim(torch.nn.Module):
     IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.

     Args:
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        n_iter (int, optional): Number of iteration for phase recovery process.
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        n_iter (int, optional): Number of iteration for phase recovery process. (Default: ``32``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        power (float): Exponent for the magnitude spectrogram,
+        power (float, optional): Exponent for the magnitude spectrogram,
             (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
-        normalized (bool): Whether to normalize by magnitude after stft. (Default: ``False``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
-        momentum (float): The momentum parameter for fast Griffin-Lim.
+        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
+        momentum (float, optional): The momentum parameter for fast Griffin-Lim.
             Setting this to 0 recovers the original Griffin-Lim method.
-            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: 0.99)
+            Values near 1 can lead to faster convergence, but above 1 may not converge. (Default: ``0.99``)
         length (int, optional): Array length of the expected output. (Default: ``None``)
-        rand_init (bool): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
+        rand_init (bool, optional): Initializes phase randomly if True and to zero otherwise. (Default: ``True``)
     """
     __constants__ = ['n_fft', 'n_iter', 'win_length', 'hop_length', 'power', 'normalized',
                      'length', 'momentum', 'rand_init']
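As an illustration of how GriffinLim pairs with Spectrogram (a sketch, assuming `waveform` as in the first example; the STFT parameters of the two transforms are chosen to match):

>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400, power=2)(waveform)
>>> restored = torchaudio.transforms.GriffinLim(n_fft=400, n_iter=32, power=2)(specgram)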
@@ -145,7 +144,7 @@ class AmplitudeToDB(torch.nn.Module):
     a full clip.

     Args:
-        stype (str): scale of input tensor ('power' or 'magnitude'). The
+        stype (str, optional): scale of input tensor ('power' or 'magnitude'). The
             power being the elementwise square of the magnitude. (Default: ``'power'``)
         top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
             is 80. (Default: ``None``)
@@ -164,14 +163,14 @@ class AmplitudeToDB(torch.nn.Module):
         self.db_multiplier = math.log10(max(self.amin, self.ref_value))

     def forward(self, x):
-        r"""Numerically stable implementation from Librosa
+        r"""Numerically stable implementation from Librosa.
         https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html

         Args:
-            x (torch.Tensor): Input tensor before being converted to decibel scale
+            x (torch.Tensor): Input tensor before being converted to decibel scale.

         Returns:
-            torch.Tensor: Output tensor in decibel scale
+            torch.Tensor: Output tensor in decibel scale.
         """
         return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
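For instance, converting a power spectrogram to decibels with an 80 dB floor (editor's sketch; `waveform` assumed as above):

>>> specgram = torchaudio.transforms.Spectrogram()(waveform)  # power spectrogram by default
>>> specgram_db = torchaudio.transforms.AmplitudeToDB(stype='power', top_db=80.)(specgram)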
@@ -183,12 +182,12 @@ class MelScale(torch.nn.Module):
     User can control which device the filter bank (`fb`) is (e.g. fb.to(spec_f.device)).

     Args:
-        n_mels (int): Number of mel filterbanks. (Default: ``128``)
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        f_min (float): Minimum frequency. (Default: ``0.``)
-        f_max (float, optional): Maximum frequency. (Default: ``sample_rate // 2``)
+        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        f_min (float, optional): Minimum frequency. (Default: ``0.``)
+        f_max (float or None, optional): Maximum frequency. (Default: ``sample_rate // 2``)
         n_stft (int, optional): Number of bins in STFT. Calculated from first input
             if None is given. See ``n_fft`` in :class:`Spectrogram`.
             (Default: ``None``)
     """
     __constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
@@ -208,10 +207,10 @@ class MelScale(torch.nn.Module):
     def forward(self, specgram):
         r"""
         Args:
-            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time)
+            specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time).

         Returns:
-            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
+            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
         """
         # pack batch
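A sketch of applying MelScale to a Spectrogram output; with ``n_stft=None`` the filter bank is inferred from the first input (values illustrative, `waveform` and `sample_rate` as above):

>>> specgram = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)                        # (channel, 201, frames)
>>> mel_specgram = torchaudio.transforms.MelScale(n_mels=128, sample_rate=sample_rate)(specgram)
>>> mel_specgram.shape                                                                        # (channel, 128, frames)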
@@ -328,18 +327,17 @@ class MelSpectrogram(torch.nn.Module):
     * http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html

     Args:
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        win_length (int): Window size. (Default: ``n_fft``)
-        hop_length (int, optional): Length of hop between STFT windows. (
-            Default: ``win_length // 2``)
-        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins
-        f_min (float): Minimum frequency. (Default: ``0.``)
-        f_max (float, optional): Maximum frequency. (Default: ``None``)
-        pad (int): Two sided padding of signal. (Default: ``0``)
-        n_mels (int): Number of mel filterbanks. (Default: ``128``)
-        window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        win_length (int or None, optional): Window size. (Default: ``n_fft``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
+        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
+        f_min (float, optional): Minimum frequency. (Default: ``0.``)
+        f_max (float or None, optional): Maximum frequency. (Default: ``None``)
+        pad (int, optional): Two sided padding of signal. (Default: ``0``)
+        n_mels (int, optional): Number of mel filterbanks. (Default: ``128``)
+        window_fn (Callable[[...], torch.Tensor], optional): A function to create a window tensor
             that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
-        wkwargs (Dict[..., ...]): Arguments for window function. (Default: ``None``)
+        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)

     Example
         >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
@@ -367,10 +365,10 @@ class MelSpectrogram(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time)
+            torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time).
         """
         specgram = self.spectrogram(waveform)
         mel_specgram = self.mel_scale(specgram)
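As the forward pass above shows, MelSpectrogram is the one-step equivalent of chaining Spectrogram and MelScale; a usage sketch (parameters illustrative):

>>> transform = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate, n_fft=400, n_mels=128)
>>> mel_specgram = transform(waveform)  # (channel, 128, frames)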
@@ -378,7 +376,7 @@ class MelSpectrogram(torch.nn.Module):
 class MFCC(torch.nn.Module):
-    r"""Create the Mel-frequency cepstrum coefficients from an audio signal
+    r"""Create the Mel-frequency cepstrum coefficients from an audio signal.

     By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
     This is not the textbook implementation, but is implemented here to
@@ -389,12 +387,11 @@ class MFCC(torch.nn.Module):
     a full clip.

     Args:
-        sample_rate (int): Sample rate of audio signal. (Default: ``16000``)
-        n_mfcc (int): Number of mfc coefficients to retain. (Default: ``40``)
-        dct_type (int): type of DCT (discrete cosine transform) to use. (Default: ``2``)
+        sample_rate (int, optional): Sample rate of audio signal. (Default: ``16000``)
+        n_mfcc (int, optional): Number of mfc coefficients to retain. (Default: ``40``)
+        dct_type (int, optional): type of DCT (discrete cosine transform) to use. (Default: ``2``)
         norm (str, optional): norm to use. (Default: ``'ortho'``)
-        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled. (Default:
-            ``False``)
+        log_mels (bool, optional): whether to use log-mel spectrograms instead of db-scaled. (Default: ``False``)
         melkwargs (dict, optional): arguments for MelSpectrogram. (Default: ``None``)
     """
     __constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
@@ -426,10 +423,10 @@ class MFCC(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): Tensor of audio of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time)
+            torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
         """
         # pack batch
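Usage sketch for MFCC (editor's illustration; `waveform` and `sample_rate` as in the first example, `n_mfcc` illustrative):

>>> mfcc = torchaudio.transforms.MFCC(sample_rate=sample_rate, n_mfcc=40)(waveform)
>>> mfcc.shape  # (channel, 40, frames)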
@@ -460,7 +457,7 @@ class MuLawEncoding(torch.nn.Module):
     returns a signal encoded with values from 0 to quantization_channels - 1

     Args:
-        quantization_channels (int): Number of channels (Default: ``256``)
+        quantization_channels (int, optional): Number of channels. (Default: ``256``)
     """
     __constants__ = ['quantization_channels']
@@ -471,10 +468,10 @@ class MuLawEncoding(torch.nn.Module):
     def forward(self, x):
         r"""
         Args:
-            x (torch.Tensor): A signal to be encoded
+            x (torch.Tensor): A signal to be encoded.

         Returns:
-            x_mu (torch.Tensor): An encoded signal
+            x_mu (torch.Tensor): An encoded signal.
         """
         return F.mu_law_encoding(x, self.quantization_channels)
@@ -487,7 +484,7 @@ class MuLawDecoding(torch.nn.Module):
     and returns a signal scaled between -1 and 1.

     Args:
-        quantization_channels (int): Number of channels (Default: ``256``)
+        quantization_channels (int, optional): Number of channels. (Default: ``256``)
     """
     __constants__ = ['quantization_channels']
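A round-trip sketch with the two mu-law transforms (editor's illustration; the input is expected to be scaled to [-1, 1], and the decoded signal is only an approximation of the original):

>>> encoded = torchaudio.transforms.MuLawEncoding(quantization_channels=256)(waveform)  # integers in [0, 255]
>>> decoded = torchaudio.transforms.MuLawDecoding(quantization_channels=256)(encoded)   # back in [-1., 1.]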
@@ -498,23 +495,23 @@ class MuLawDecoding(torch.nn.Module):
     def forward(self, x_mu):
         r"""
         Args:
-            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded
+            x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded.

         Returns:
-            torch.Tensor: The signal decoded
+            torch.Tensor: The signal decoded.
         """
         return F.mu_law_decoding(x_mu, self.quantization_channels)


 class Resample(torch.nn.Module):
-    r"""Resample a signal from one frequency to another. A resampling method can
-    be given.
+    r"""Resample a signal from one frequency to another. A resampling method can be given.

     Args:
-        orig_freq (float): The original frequency of the signal. (Default: ``16000``)
-        new_freq (float): The desired frequency. (Default: ``16000``)
-        resampling_method (str): The resampling method (Default: ``'sinc_interpolation'``)
+        orig_freq (float, optional): The original frequency of the signal. (Default: ``16000``)
+        new_freq (float, optional): The desired frequency. (Default: ``16000``)
+        resampling_method (str, optional): The resampling method. (Default: ``'sinc_interpolation'``)
     """

     def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_interpolation'):
         super(Resample, self).__init__()
         self.orig_freq = orig_freq
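Usage sketch for Resample (the 8000 Hz target is illustrative):

>>> resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=8000)
>>> downsampled = resample(waveform)  # time dimension scaled by roughly 8000 / sample_rate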
@@ -524,10 +521,10 @@ class Resample(torch.nn.Module):
     def forward(self, waveform):
         r"""
         Args:
-            waveform (torch.Tensor): The input signal of dimension (..., time)
+            waveform (torch.Tensor): Tensor of audio of dimension (..., time).

         Returns:
-            torch.Tensor: Output signal of dimension (..., time)
+            torch.Tensor: Output signal of dimension (..., time).
         """
         if self.resampling_method == 'sinc_interpolation':
@@ -546,9 +543,10 @@ class Resample(torch.nn.Module):
 class ComplexNorm(torch.nn.Module):
-    r"""Compute the norm of complex tensor input
+    r"""Compute the norm of complex tensor input.

     Args:
-        power (float): Power of the norm. Defaults to `1.0`.
+        power (float, optional): Power of the norm. (Default: to ``1.0``)
     """
     __constants__ = ['power']
@@ -559,9 +557,10 @@ class ComplexNorm(torch.nn.Module):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`.

         Returns:
-            Tensor: norm of the input tensor, shape of `(..., )`
+            Tensor: norm of the input tensor, shape of `(..., )`.
         """
         return F.complex_norm(complex_tensor, self.power)
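Per the updated Spectrogram docstring above, ``power=None`` returns a complex spectrum, which ComplexNorm can reduce to a magnitude; a sketch under that assumption:

>>> complex_specgram = torchaudio.transforms.Spectrogram(power=None)(waveform)  # (channel, freq, frames, 2)
>>> magnitude = torchaudio.transforms.ComplexNorm(power=1.)(complex_specgram)   # (channel, freq, frames)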
@@ -572,7 +571,8 @@ class ComputeDeltas(torch.nn.Module):
     See `torchaudio.functional.compute_deltas` for more details.

     Args:
-        win_length (int): The window length used for computing delta.
+        win_length (int): The window length used for computing delta. (Default: ``5``)
+        mode (str): Mode parameter passed to padding. (Default: ``'replicate'``)
     """
     __constants__ = ['win_length']
@@ -584,10 +584,10 @@ class ComputeDeltas(torch.nn.Module):
     def forward(self, specgram):
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time)
+            specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time).

         Returns:
-            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time)
+            deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time).
         """
         return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)
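A sketch computing delta features over a mel spectrogram (illustrative; the output has the same shape as its input):

>>> mel_specgram = torchaudio.transforms.MelSpectrogram(sample_rate=sample_rate)(waveform)
>>> deltas = torchaudio.transforms.ComputeDeltas(win_length=5)(mel_specgram)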
@@ -596,9 +596,9 @@ class TimeStretch(torch.nn.Module):
     r"""Stretch stft in time without modifying pitch for a given rate.

     Args:
-        hop_length (int): Number audio of frames between STFT columns. (Default: ``n_fft // 2``)
+        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
         n_freq (int, optional): number of filter banks from stft. (Default: ``201``)
-        fixed_rate (float): rate to speed up or slow down by.
+        fixed_rate (float or None, optional): rate to speed up or slow down by.
             If None is provided, rate must be passed to the forward method. (Default: ``None``)
     """
     __constants__ = ['fixed_rate']
@@ -616,12 +616,12 @@ class TimeStretch(torch.nn.Module):
         # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
-            overriding_rate (float or None): speed up to apply to this batch.
-                If no rate is passed, use ``self.fixed_rate``
+            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2).
+            overriding_rate (float or None, optional): speed up to apply to this batch.
+                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)

         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2).
         """
         assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
@@ -643,9 +643,9 @@ class _AxisMasking(torch.nn.Module):
     r"""Apply masking to a spectrogram.

     Args:
-        mask_param (int): Maximum possible length of the mask
-        axis: What dimension the mask is applied on
-        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension
+        mask_param (int): Maximum possible length of the mask.
+        axis (int): What dimension the mask is applied on.
+        iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
     """
     __constants__ = ['mask_param', 'axis', 'iid_masks']
@@ -660,10 +660,11 @@ class _AxisMasking(torch.nn.Module):
         # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
+            specgram (torch.Tensor): Tensor of dimension (..., freq, time).
+            mask_value (float): Value to assign to the masked columns.

         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (..., freq, time).
         """
         # if iid_masks flag marked and specgram has a batch dimension
@@ -679,8 +680,8 @@ class FrequencyMasking(_AxisMasking):
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
-        iid_masks (bool): weather to apply the same mask to all
-            the examples/channels in the batch. (Default: False)
+        iid_masks (bool, optional): weather to apply the same mask to all
+            the examples/channels in the batch. (Default: ``False``)
     """

     def __init__(self, freq_mask_param, iid_masks=False):
@@ -693,8 +694,8 @@ class TimeMasking(_AxisMasking):
     Args:
         time_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, time_mask_param).
-        iid_masks (bool): weather to apply the same mask to all
-            the examples/channels in the batch. Defaults to False.
+        iid_masks (bool, optional): weather to apply the same mask to all
+            the examples/channels in the batch. (Default: ``False``)
     """

     def __init__(self, time_mask_param, iid_masks=False):
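A SpecAugment-style sketch chaining the two masking transforms on a spectrogram (mask lengths are illustrative; masked entries default to a value of 0):

>>> specgram = torchaudio.transforms.Spectrogram()(waveform)
>>> specgram = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)(specgram)
>>> specgram = torchaudio.transforms.TimeMasking(time_mask_param=80)(specgram)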