OpenDAS / Torchaudio · Commit 21a0d29e (unverified)

Standardize tensor shapes format in docs (#1838)

Authored Oct 07, 2021 by Caroline Chen, committed by GitHub on Oct 07, 2021.
Parent: d857348f

Showing 5 changed files with 80 additions and 80 deletions (+80 −80):

torchaudio/functional/filtering.py        +10 −10
torchaudio/functional/functional.py       +26 −26
torchaudio/models/wav2vec2/components.py   +4 −4
torchaudio/models/wav2vec2/model.py        +8 −8
torchaudio/transforms.py                  +32 −32
torchaudio/functional/filtering.py  (+10 −10)

@@ -650,20 +650,20 @@ def filtfilt(
     Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html

     Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
         a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
             Must be same size as b_coeffs (pad with 0's as necessary).
         b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
             Must be same size as a_coeffs (pad with 0's as necessary).
         clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

     Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-        are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+        are 2D Tensors, or `(..., time)` otherwise.
     """
     forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
     backward_filtered = lfilter(

@@ -970,13 +970,13 @@ def lfilter(
     Using double precision could also minimize numerical precision errors.

     Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
         a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
             Must be same size as b_coeffs (pad with 0's as necessary).
         b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
             Must be same size as a_coeffs (pad with 0's as necessary).
         clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

@@ -986,8 +986,8 @@ def lfilter(
             a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)

     Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-        are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+        are 2D Tensors, or `(..., time)` otherwise.
     """
     assert a_coeffs.size() == b_coeffs.size()
     assert a_coeffs.ndim <= 2
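Note: a minimal sketch of the shape conventions documented in the hunks above, using the lfilter/filtfilt signatures shown there. The coefficient values and waveform size are arbitrary illustration values, not part of the commit.

    import torch
    import torchaudio.functional as F

    # Illustrative coefficients: a simple 3-tap FIR, i.e. b_coeffs of shape
    # `(num_order + 1)` and a_coeffs = [1, 0, 0] (same size, padded with zeros).
    waveform = torch.rand(2, 8000) * 2 - 1          # `(..., time)`, normalized to [-1, 1]
    b_coeffs = torch.tensor([0.25, 0.5, 0.25])
    a_coeffs = torch.tensor([1.0, 0.0, 0.0])

    filtered = F.lfilter(waveform, a_coeffs, b_coeffs)       # `(..., time)`
    zero_phase = F.filtfilt(waveform, a_coeffs, b_coeffs)    # forward-backward filtering
    print(filtered.shape, zero_phase.shape)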
torchaudio/functional/functional.py  (+26 −26)

@@ -62,7 +62,7 @@ def spectrogram(
     The spectrogram can be either magnitude-only or complex.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
         pad (int): Two sided padding of signal
         window (Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of FFT

@@ -89,7 +89,7 @@ def spectrogram(
             power spectrogram, which is a real-valued tensor.

     Returns:
-        Tensor: Dimension (..., freq, time), freq is
+        Tensor: Dimension `(..., freq, time)`, freq is
         ``n_fft // 2 + 1`` and ``n_fft`` is the number of
         Fourier bins, and time is the number of window hops (n_frame).
     """
@@ -172,7 +172,7 @@ def inverse_spectrogram(
             Default: ``True``

     Returns:
-        Tensor: Dimension (..., time). Least squares estimation of the original signal.
+        Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
     """
     if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:

@@ -246,7 +246,7 @@ def griffinlim(
     and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].

     Args:
-        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
+        specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
             where freq is ``n_fft // 2 + 1``.
         window (Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins

@@ -263,7 +263,7 @@ def griffinlim(
         rand_init (bool): Initializes phase randomly if True, to zero otherwise.

     Returns:
-        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
+        torch.Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
     """
     assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
     assert momentum >= 0, 'momentum={} < 0'.format(momentum)
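Note: a sketch of the round trip these docstrings describe (magnitude spectrogram in, `(..., time)` waveform out), using the transform-level wrappers. The use of torchaudio.transforms.Spectrogram/GriffinLim and the n_fft value are illustrative assumptions, not part of this diff.

    import torch
    import torchaudio

    waveform = torch.randn(1, 16000)

    spec = torchaudio.transforms.Spectrogram(n_fft=400)(waveform)       # power spectrogram, `(..., freq, frames)`
    reconstructed = torchaudio.transforms.GriffinLim(n_fft=400)(spec)   # `(..., time)`
    print(reconstructed.shape)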
@@ -791,10 +791,10 @@ def phase_vocoder(
     Args:
         complex_specgrams (Tensor):
-            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+            Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+            or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
         rate (float): Speed-up factor
-        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
+        phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`

     Returns:
         Tensor:
@@ -907,13 +907,13 @@ def mask_along_axis_iid(
     ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

     Args:
-        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
+        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

     Returns:
-        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
+        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
     """
     if axis not in [2, 3]:

@@ -950,13 +950,13 @@ def mask_along_axis(
     All examples will have the same mask interval.

     Args:
-        specgram (Tensor): Real spectrogram (channel, freq, time)
+        specgram (Tensor): Real spectrogram `(channel, freq, time)`
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

     Returns:
-        Tensor: Masked spectrogram of dimensions (channel, freq, time)
+        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
     """
     if axis not in [1, 2]:
         raise ValueError('Only Frequency and Time masking are supported')
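Note: a minimal sketch of the masking shapes documented above; the spectrogram size and mask parameters are arbitrary illustration values.

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(1, 201, 400)        # `(channel, freq, time)`

    # Mask up to 30 frequency rows (axis=1) and up to 80 time columns (axis=2).
    masked = F.mask_along_axis(specgram, mask_param=30, mask_value=0.0, axis=1)
    masked = F.mask_along_axis(masked, mask_param=80, mask_value=0.0, axis=2)

    # The per-example variant expects `(batch, channel, freq, time)` and axis 2 or 3.
    batched = F.mask_along_axis_iid(specgram.unsqueeze(0), mask_param=80, mask_value=0.0, axis=3)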
@@ -999,12 +999,12 @@ def compute_deltas(
     :math:`N` is ``(win_length-1)//2``.

     Args:
-        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
+        specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
         win_length (int, optional): The window length used for computing delta (Default: ``5``)
         mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

     Returns:
-        Tensor: Tensor of deltas of dimension (..., freq, time)
+        Tensor: Tensor of deltas of dimension `(..., freq, time)`

     Example
         >>> specgram = torch.randn(1, 40, 1000)

@@ -1172,7 +1172,7 @@ def detect_pitch_frequency(
     It is implemented using normalized cross-correlation function and median smoothing.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., freq, time)
+        waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
         sample_rate (int): The sample rate of the waveform (Hz)
         frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
         win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).

@@ -1180,7 +1180,7 @@ def detect_pitch_frequency(
         freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

     Returns:
-        Tensor: Tensor of freq of dimension (..., frame)
+        Tensor: Tensor of freq of dimension `(..., frame)`
     """
     # pack batch
     shape = list(waveform.size())
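Note: a small sketch of the pitch-detection shapes documented above. The test signal (a 440 Hz tone) and sample rate are illustrative values, not part of the commit.

    import math

    import torch
    import torchaudio.functional as F

    sample_rate = 16000
    t = torch.arange(0, 1, 1 / sample_rate)
    waveform = torch.sin(2 * math.pi * 440 * t).unsqueeze(0)     # `(..., time)`

    pitch = F.detect_pitch_frequency(waveform, sample_rate)      # `(..., frame)`
    print(pitch.shape, pitch.mean())                             # mean should be close to 440 Hz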
@@ -1211,7 +1211,7 @@ def sliding_window_cmn(
     Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

     Args:
-        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
+        specgram (Tensor): Tensor of audio of dimension `(..., time, freq)`
         cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
         min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
             Only applicable if center == false, ignored if center==true (int, default = 100)

@@ -1220,7 +1220,7 @@ def sliding_window_cmn(
         norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

     Returns:
-        Tensor: Tensor matching input shape (..., freq, time)
+        Tensor: Tensor matching input shape `(..., freq, time)`
     """
     input_shape = specgram.shape
     num_frames, num_feats = input_shape[-2:]

@@ -1307,7 +1307,7 @@ def spectral_centroid(
     frequency values, weighted by their magnitude.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
         sample_rate (int): Sample rate of the audio waveform
         pad (int): Two sided padding of signal
         window (Tensor): Window tensor that is applied/multiplied to each frame/window

@@ -1316,7 +1316,7 @@ def spectral_centroid(
         win_length (int): Window size

     Returns:
-        Tensor: Dimension (..., time)
+        Tensor: Dimension `(..., time)`
     """
     specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length, power=1., normalized=False)
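Note: a sketch mirroring the spectral_centroid parameters shown in the hunks above; all numeric values are arbitrary illustration values.

    import torch
    import torchaudio.functional as F

    waveform = torch.randn(1, 16000)       # `(..., time)`
    window = torch.hann_window(400)

    centroid = F.spectral_centroid(
        waveform,
        sample_rate=16000,
        pad=0,
        window=window,
        n_fft=400,
        hop_length=200,
        win_length=400,
    )
    print(centroid.shape)                  # `(..., time)`: one centroid value (in Hz) per frame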
@@ -1344,8 +1344,8 @@ def apply_codec(
         sample_rate (int): Sample rate of the audio waveform.
         format (str): File format.
         channels_first (bool, optional):
-            When True, both the input and output Tensor have dimension ``[channel, time]``.
-            Otherwise, they have dimension ``[time, channel]``.
+            When True, both the input and output Tensor have dimension `(channel, time)`.
+            Otherwise, they have dimension `(time, channel)`.
         compression (float or None, optional): Used for formats other than WAV.
             For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
         encoding (str or None, optional): Changes the encoding for the supported formats.

@@ -1355,7 +1355,7 @@ def apply_codec(
     Returns:
         torch.Tensor: Resulting Tensor.
-        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
+        If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
     """
     bytes = io.BytesIO()
     torchaudio.backend.sox_io_backend.save(bytes,
@@ -1453,7 +1453,7 @@ def compute_kaldi_pitch(
             This makes different types of features give the same number of frames. (default: True)

     Returns:
-        Tensor: Pitch feature. Shape: ``(batch, frames 2)`` where the last dimension
+        Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
         corresponds to pitch and NCCF.
     """
     shape = waveform.shape

@@ -1605,7 +1605,7 @@ def resample(
     more efficient computation if resampling multiple waveforms with the same resampling parameters.

     Args:
-        waveform (Tensor): The input signal of dimension (..., time)
+        waveform (Tensor): The input signal of dimension `(..., time)`
         orig_freq (float): The original frequency of the signal
         new_freq (float): The desired frequency
         lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper

@@ -1617,7 +1617,7 @@ def resample(
         beta (float or None, optional): The shape parameter used for kaiser window.

     Returns:
-        Tensor: The waveform at the new frequency of dimension (..., time).
+        Tensor: The waveform at the new frequency of dimension `(..., time).`
     """
     assert orig_freq > 0.0 and new_freq > 0.0
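Note: a one-line sketch of the resample shapes documented above; the source and target rates are illustrative values.

    import torch
    import torchaudio.functional as F

    waveform = torch.randn(1, 44100)                                  # one second at 44.1 kHz, `(..., time)`
    resampled = F.resample(waveform, orig_freq=44100, new_freq=16000)
    print(resampled.shape)                                            # roughly (1, 16000)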
torchaudio/models/wav2vec2/components.py  (+4 −4)

@@ -301,9 +301,9 @@ class FeedForward(Module):
     def forward(self, x):
         """
         Args:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
         Returns:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
         """
         x = self.intermediate_dense(x)
         x = torch.nn.functional.gelu(x)

@@ -339,9 +339,9 @@ class EncoderLayer(Module):
     ):
         """
         Args:
-            x (Tensor): shape: ``(batch, sequence_length, embed_dim)``
+            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
             attention_mask (Tensor or None, optional):
-                shape: ``(batch, 1, sequence_length, sequence_length)``
+                shape: `(batch, 1, sequence_length, sequence_length)`
         """
         residual = x
torchaudio/models/wav2vec2/model.py  (+8 −8)

@@ -48,10 +48,10 @@ class Wav2Vec2Model(Module):
             transformer block in encoder.

         Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
             lengths (Tensor or None, optional):
                 Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.
             num_layers (int or None, optional):
                 If given, limit the number of intermediate layers to go through.
                 Providing `1` will stop the computation after going through one

@@ -62,9 +62,9 @@ class Wav2Vec2Model(Module):
             List of Tensors and an optional Tensor:
             List of Tensors
                 Features from requested layers.
-                Each Tensor is of shape: ``(batch, frames, feature dimention)``
+                Each Tensor is of shape: `(batch, frames, feature dimention)`
             Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                 is retuned. It indicates the valid length of each feature in the batch.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)

@@ -79,18 +79,18 @@ class Wav2Vec2Model(Module):
         """Compute the sequence of probability distribution over labels.

         Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
             lengths (Tensor or None, optional):
                 Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.

         Returns:
             Tensor and an optional Tensor:
             Tensor
                 The sequences of probability distribution (in logit) over labels.
-                Shape: ``(batch, frames, num labels)``.
+                Shape: `(batch, frames, num labels)`.
             Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                 is retuned. It indicates the valid length of each feature in the batch.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
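Note: a sketch of the call shapes documented above. The use of a torchaudio.pipelines bundle to obtain a Wav2Vec2Model instance is an assumption (it is not part of this diff, and it downloads a pretrained checkpoint); only the forward/extract_features shapes come from the docstrings.

    import torch
    import torchaudio

    # Assumption: any factory or pipeline yielding a Wav2Vec2Model works here.
    model = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H.get_model()

    waveforms = torch.randn(2, 16000)                 # `(batch, frames)`
    lengths = torch.tensor([16000, 12000])            # `(batch, )`

    with torch.inference_mode():
        features, valid = model.extract_features(waveforms, lengths)
        # features: list of Tensors, each `(batch, frames, feature dimension)`
        logits, valid = model(waveforms, lengths)
        # logits: `(batch, frames, num labels)`; valid: `(batch, )` or None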
torchaudio/transforms.py  (+32 −32)

@@ -971,8 +971,8 @@ class TimeStretch(torch.nn.Module):
         r"""
         Args:
             complex_specgrams (Tensor):
-                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+                Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+                or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
             overriding_rate (float or None, optional): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
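Note: a sketch of feeding TimeStretch the complex-dtype input described above. Obtaining it via Spectrogram(power=None), plus the n_fft and rate values, are illustrative assumptions, not part of this diff.

    import torch
    import torchaudio

    waveform = torch.randn(1, 16000)
    complex_spec = torchaudio.transforms.Spectrogram(n_fft=400, power=None)(waveform)
    # complex_spec: `(..., freq, num_frame)` with complex dtype

    stretch = torchaudio.transforms.TimeStretch(n_freq=201)     # n_freq = n_fft // 2 + 1
    faster = stretch(complex_spec, overriding_rate=1.2)         # roughly 20% fewer frames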
@@ -1018,10 +1018,10 @@ class Fade(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         waveform_length = waveform.size()[-1]
         device = waveform.device
@@ -1092,11 +1092,11 @@ class _AxisMasking(torch.nn.Module):
     def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
         r"""
         Args:
-            specgram (Tensor): Tensor of dimension (..., freq, time).
+            specgram (Tensor): Tensor of dimension `(..., freq, time)`.
             mask_value (float): Value to assign to the masked columns.

         Returns:
-            Tensor: Masked spectrogram of dimensions (..., freq, time).
+            Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
         """
         # if iid_masks flag marked and specgram has a batch dimension
         if self.iid_masks and specgram.dim() == 4:
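Note: _AxisMasking is the shared base of the public FrequencyMasking and TimeMasking transforms; a minimal sketch of the `(..., freq, time)` contract with arbitrary mask parameters:

    import torch
    import torchaudio

    specgram = torch.randn(1, 201, 400)               # `(..., freq, time)`

    freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=30)
    time_mask = torchaudio.transforms.TimeMasking(time_mask_param=80)

    augmented = time_mask(freq_mask(specgram))        # still `(..., freq, time)`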
@@ -1157,10 +1157,10 @@ class Vol(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         if self.gain_type == "amplitude":
             waveform = waveform * self.gain

@@ -1201,10 +1201,10 @@ class SlidingWindowCmn(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         cmn_waveform = F.sliding_window_cmn(
             waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
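Note: a short sketch of the Vol transform whose "amplitude" branch appears in the hunk above; the gain values and the "db" variant are illustrative choices.

    import torch
    import torchaudio

    waveform = torch.rand(1, 16000) * 2 - 1            # `(..., time)`, normalized to [-1, 1]

    quieter = torchaudio.transforms.Vol(gain=0.5, gain_type="amplitude")(waveform)
    attenuated = torchaudio.transforms.Vol(gain=-6.0, gain_type="db")(waveform)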
@@ -1374,10 +1374,10 @@ class SpectralCentroid(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Spectral Centroid of size (..., time).
+            Tensor: Spectral Centroid of size `(..., time)`.
         """
         return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,

@@ -1428,7 +1428,7 @@ class PitchShift(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
             Tensor: The pitch-shifted audio of shape `(..., time)`.
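Note: the transform wraps F.spectral_centroid with stored STFT parameters; a minimal sketch with an arbitrary sample rate:

    import torch
    import torchaudio

    waveform = torch.randn(1, 16000)                     # `(..., time)`

    transform = torchaudio.transforms.SpectralCentroid(sample_rate=16000)
    centroid = transform(waveform)                       # `(..., time)`, one value per frame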
@@ -1513,7 +1513,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
-        input (torch.Tensor): Tensor of dimension (..., channel, channel)
+        input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
        dim1 (int, optional): the first dimension of the diagonal matrix
            (Default: -1)
        dim2 (int, optional): the second dimension of the diagonal matrix

@@ -1548,14 +1548,14 @@ class PSD(torch.nn.Module):
        """
        Args:
            specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
-                of dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
+                of dimension `(..., channel, freq, time)` if multi_mask is ``True``

        Returns:
            torch.Tensor: PSD matrix of the input STFT matrix.
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
        """
        # outer product:
        # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
@@ -1804,11 +1804,11 @@ class MVDR(torch.nn.Module):
        Args:
            psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`

        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
        """
        w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
        _, indices = torch.max(w.abs(), dim=-1, keepdim=True)

@@ -1826,14 +1826,14 @@ class MVDR(torch.nn.Module):
        Args:
            psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
            psd_n (torch.Tensor): covariance matrix of noise
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
            reference_vector (torch.Tensor): one-hot reference channel matrix

        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
        """
        phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
        stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])

@@ -1850,13 +1850,13 @@ class MVDR(torch.nn.Module):
        r"""Apply the beamforming weight to the noisy STFT
        Args:
            specgram (torch.tensor): multi-channel noisy STFT
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            beamform_vector (torch.Tensor): beamforming weight matrix
-                Tensor of dimension (..., freq, channel)
+                Tensor of dimension `(..., freq, channel)`

        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
        """
        # (..., channel) x (..., channel, freq, time) -> (..., freq, time)
        specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])

@@ -1897,18 +1897,18 @@ class MVDR(torch.nn.Module):
        Args:
            specgram (torch.Tensor): the multi-channel STF of the noisy speech.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            mask_s (torch.Tensor): Time-Frequency mask of target speech.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
                (Default: None)

        Returns:
            torch.Tensor: The single-channel STFT of the enhanced speech.
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
        """
        if specgram.ndim < 3:
            raise ValueError(
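Note: a sketch only, assuming PSD and MVDR are exposed as torchaudio.transforms.PSD / torchaudio.transforms.MVDR and that their default constructor arguments are acceptable; neither assumption comes from this diff, which only documents the forward-pass shapes used below.

    import torch
    import torchaudio

    n_fft = 400
    waveform = torch.randn(6, 16000)                               # 6-channel audio
    specgram = torchaudio.transforms.Spectrogram(n_fft=n_fft, power=None)(waveform)
    # specgram: `(..., channel, freq, time)`, complex-valued

    mask_s = torch.rand(specgram.shape[-2:])                       # `(..., freq, time)` speech mask
    mask_n = 1.0 - mask_s                                          # noise mask

    psd = torchaudio.transforms.PSD()(specgram, mask_s)            # `(..., freq, channel, channel)`
    enhanced = torchaudio.transforms.MVDR()(specgram, mask_s, mask_n)   # `(..., freq, time)`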