OpenDAS / Torchaudio / Commits

Unverified commit 21a0d29e
Authored Oct 07, 2021 by Caroline Chen; committed by GitHub on Oct 07, 2021
Parent: d857348f

Standardize tensor shapes format in docs (#1838)

Showing 5 changed files, with 80 additions and 80 deletions:
torchaudio/functional/filtering.py        +10 -10
torchaudio/functional/functional.py       +26 -26
torchaudio/models/wav2vec2/components.py   +4  -4
torchaudio/models/wav2vec2/model.py        +8  -8
torchaudio/transforms.py                  +32 -32
torchaudio/functional/filtering.py
@@ -650,20 +650,20 @@ def filtfilt(
     Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html

     Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
         a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
             Must be same size as b_coeffs (pad with 0's as necessary).
         b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
             Must be same size as a_coeffs (pad with 0's as necessary).
         clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)

     Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-            are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+            are 2D Tensors, or `(..., time)` otherwise.
     """
     forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
     backward_filtered = lfilter(
@@ -970,13 +970,13 @@ def lfilter(
     Using double precision could also minimize numerical precision errors.

     Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
         a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
             Must be same size as b_coeffs (pad with 0's as necessary).
         b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
             Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
             Must be same size as a_coeffs (pad with 0's as necessary).
         clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
@@ -986,8 +986,8 @@ def lfilter(
             a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)

     Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-            are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+            are 2D Tensors, or `(..., time)` otherwise.
     """
     assert a_coeffs.size() == b_coeffs.size()
     assert a_coeffs.ndim <= 2
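The coefficient ordering in the ``lfilter`` docstring above (lowest delay first, ``[b0, b1, ...]`` and ``[a0, a1, ...]``) is the standard IIR difference equation. A minimal pure-Python sketch of that recursion, for illustration only (torchaudio's actual implementation is batched and vectorized):

```python
def iir_filter(x, b_coeffs, a_coeffs):
    """Direct-form IIR filter.

    a0 * y[n] = sum_k b[k] * x[n-k] - sum_{k>=1} a[k] * y[n-k],
    with coefficients ordered lowest delay first, as in the docstring.
    """
    y = []
    for n in range(len(x)):
        acc = sum(b * x[n - k] for k, b in enumerate(b_coeffs) if n - k >= 0)
        acc -= sum(a * y[n - k] for k, a in enumerate(a_coeffs) if k >= 1 and n - k >= 0)
        y.append(acc / a_coeffs[0])
    return y

# A pure moving-average (FIR) filter: a = [1.0], b = [0.5, 0.5]
print(iir_filter([1.0, 1.0, 1.0, 1.0], [0.5, 0.5], [1.0]))  # [0.5, 1.0, 1.0, 1.0]
```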
torchaudio/functional/functional.py
@@ -62,7 +62,7 @@ def spectrogram(
     The spectrogram can be either magnitude-only or complex.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
         pad (int): Two sided padding of signal
         window (Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of FFT
@@ -89,7 +89,7 @@ def spectrogram(
         power spectrogram, which is a real-valued tensor.

     Returns:
-        Tensor: Dimension (..., freq, time), freq is
+        Tensor: Dimension `(..., freq, time)`, freq is
         ``n_fft // 2 + 1`` and ``n_fft`` is the number of
         Fourier bins, and time is the number of window hops (n_frame).
     """
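The ``freq`` dimension is ``n_fft // 2 + 1`` because a real-valued input keeps only the non-negative frequency bins of the FFT. A quick arithmetic check; the frame-count formula below assumes a centered STFT (the common default), which is an assumption of this sketch rather than something stated in the diff:

```python
def num_freq_bins(n_fft: int) -> int:
    """One-sided spectrum size for a real-valued input."""
    return n_fft // 2 + 1

def num_frames(num_samples: int, hop_length: int) -> int:
    """Window-hop count, assuming center-padded framing."""
    return 1 + num_samples // hop_length

print(num_freq_bins(400))      # 201
print(num_frames(16000, 200))  # 81
```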
@@ -172,7 +172,7 @@ def inverse_spectrogram(
             Default: ``True``

     Returns:
-        Tensor: Dimension (..., time). Least squares estimation of the original signal.
+        Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
     """
     if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
@@ -246,7 +246,7 @@ def griffinlim(
     and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].

     Args:
-        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
+        specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
             where freq is ``n_fft // 2 + 1``.
         window (Tensor): Window tensor that is applied/multiplied to each frame/window
         n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
@@ -263,7 +263,7 @@ def griffinlim(
         rand_init (bool): Initializes phase randomly if True, to zero otherwise.

     Returns:
-        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
+        torch.Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
     """
     assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
     assert momentum >= 0, 'momentum={} < 0'.format(momentum)
@@ -791,10 +791,10 @@ def phase_vocoder(
     Args:
         complex_specgrams (Tensor):
-            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+            Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+            or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
         rate (float): Speed-up factor
-        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
+        phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`

     Returns:
         Tensor:
@@ -907,13 +907,13 @@ def mask_along_axis_iid(
     ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.

     Args:
-        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
+        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)

     Returns:
-        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
+        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
     """
     if axis not in [2, 3]:
@@ -950,13 +950,13 @@ def mask_along_axis(
     All examples will have the same mask interval.

     Args:
-        specgram (Tensor): Real spectrogram (channel, freq, time)
+        specgram (Tensor): Real spectrogram `(channel, freq, time)`
         mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
         mask_value (float): Value to assign to the masked columns
         axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)

     Returns:
-        Tensor: Masked spectrogram of dimensions (channel, freq, time)
+        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
     """
     if axis not in [1, 2]:
         raise ValueError('Only Frequency and Time masking are supported')
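The sampling scheme described in these docstrings (mask width ``v`` from ``uniform(0, mask_param)``, start ``v_0`` from the remaining range) can be sketched in pure Python. ``mask_columns`` is a hypothetical helper for illustration, not torchaudio's implementation, and uses integer sampling for simplicity:

```python
import random

def mask_columns(specgram, mask_param, mask_value):
    """Mask one random contiguous span of time columns in a (freq, time) grid.

    SpecAugment-style sketch: width v ~ uniform(0, mask_param),
    start v_0 ~ uniform(0, time - v); columns [v_0, v_0 + v) get mask_value.
    """
    num_cols = len(specgram[0])            # time axis
    v = random.randint(0, mask_param)      # mask width
    v_0 = random.randint(0, num_cols - v)  # mask start
    return [
        [mask_value if v_0 <= t < v_0 + v else row[t] for t in range(num_cols)]
        for row in specgram
    ]

random.seed(0)
spec = [[1.0] * 8 for _ in range(3)]       # (freq=3, time=8)
masked = mask_columns(spec, mask_param=4, mask_value=0.0)
```

As in ``mask_along_axis``, every frequency row shares the same masked interval.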
@@ -999,12 +999,12 @@ def compute_deltas(
     :math:`N` is ``(win_length-1)//2``.

     Args:
-        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
+        specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
         win_length (int, optional): The window length used for computing delta (Default: ``5``)
         mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)

     Returns:
-        Tensor: Tensor of deltas of dimension (..., freq, time)
+        Tensor: Tensor of deltas of dimension `(..., freq, time)`

     Example
         >>> specgram = torch.randn(1, 40, 1000)
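The delta formula referenced above, with ``N = (win_length-1)//2`` and replicate padding, can be written out for a single coefficient track. This is a pure-Python sketch of the standard regression-based delta, not torchaudio's vectorized code:

```python
def compute_delta(c, win_length=5):
    """Delta coefficients with replicate padding at the edges.

    d[t] = sum_{n=1..N} n * (c[t+n] - c[t-n]) / (2 * sum_{n=1..N} n**2),
    where N = (win_length - 1) // 2.
    """
    N = (win_length - 1) // 2
    denom = 2 * sum(n * n for n in range(1, N + 1))
    T = len(c)
    clamp = lambda t: min(max(t, 0), T - 1)  # "replicate" padding mode
    return [
        sum(n * (c[clamp(t + n)] - c[clamp(t - n)]) for n in range(1, N + 1)) / denom
        for t in range(T)
    ]

# On a linear ramp the interior deltas equal the slope.
print(compute_delta([0.0, 1.0, 2.0, 3.0, 4.0]))  # [0.5, 0.8, 1.0, 0.8, 0.5]
```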
@@ -1172,7 +1172,7 @@ def detect_pitch_frequency(
     It is implemented using normalized cross-correlation function and median smoothing.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., freq, time)
+        waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
         sample_rate (int): The sample rate of the waveform (Hz)
         frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
         win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
@@ -1180,7 +1180,7 @@ def detect_pitch_frequency(
         freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).

     Returns:
-        Tensor: Tensor of freq of dimension (..., frame)
+        Tensor: Tensor of freq of dimension `(..., frame)`
     """
     # pack batch
     shape = list(waveform.size())
@@ -1211,7 +1211,7 @@ def sliding_window_cmn(
     Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.

     Args:
-        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
+        specgram (Tensor): Tensor of audio of dimension `(..., time, freq)`
         cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
         min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
             Only applicable if center == false, ignored if center==true (int, default = 100)
@@ -1220,7 +1220,7 @@ def sliding_window_cmn(
         norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)

     Returns:
-        Tensor: Tensor matching input shape (..., freq, time)
+        Tensor: Tensor matching input shape `(..., freq, time)`
     """
     input_shape = specgram.shape
     num_frames, num_feats = input_shape[-2:]
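The core of sliding-window CMN is subtracting, per frame, the mean over a window of neighboring frames. The sketch below is a deliberately simplified mean-only, centered version; Kaldi-style CMN (which this function follows) has additional rules for ``min_cmn_window`` and variance normalization that are omitted here:

```python
def sliding_cmn(feats, cmn_window=3):
    """Per-frame mean subtraction over a centered window.

    feats: list of frames, each a list of feature values, i.e. (time, freq).
    Mean-only, center=True sketch; edges use the truncated window.
    """
    T = len(feats)
    D = len(feats[0])
    half = cmn_window // 2
    out = []
    for t in range(T):
        lo, hi = max(0, t - half), min(T, t + half + 1)
        window = feats[lo:hi]
        means = [sum(f[d] for f in window) / len(window) for d in range(D)]
        out.append([feats[t][d] - means[d] for d in range(D)])
    return out

feats = [[1.0], [2.0], [3.0], [4.0]]
print(sliding_cmn(feats, cmn_window=3))  # [[-0.5], [0.0], [0.0], [0.5]]
```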
@@ -1307,7 +1307,7 @@ def spectral_centroid(
     frequency values, weighted by their magnitude.

     Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
         sample_rate (int): Sample rate of the audio waveform
         pad (int): Two sided padding of signal
         window (Tensor): Window tensor that is applied/multiplied to each frame/window
@@ -1316,7 +1316,7 @@ def spectral_centroid(
         win_length (int): Window size

     Returns:
-        Tensor: Dimension (..., time)
+        Tensor: Dimension `(..., time)`
     """
     specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                            win_length=win_length, power=1., normalized=False)
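The spectral centroid is the magnitude-weighted mean of the bin frequencies, which is why the function above computes a power-1 (magnitude) spectrogram first. A pure-Python sketch for a single spectral frame:

```python
def spectral_centroid_frame(magnitudes, sample_rate, n_fft):
    """Magnitude-weighted mean frequency of one spectral frame.

    magnitudes: one-sided magnitude spectrum of length n_fft // 2 + 1.
    Bin k corresponds to frequency k * sample_rate / n_fft.
    """
    freqs = [k * sample_rate / n_fft for k in range(len(magnitudes))]
    total = sum(magnitudes)
    return sum(f * m for f, m in zip(freqs, magnitudes)) / total

# All energy in bin 2 of an 8-point FFT at 8 kHz -> centroid = 2 * 8000 / 8 = 2000 Hz
print(spectral_centroid_frame([0.0, 0.0, 1.0, 0.0, 0.0], 8000, 8))  # 2000.0
```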
@@ -1344,8 +1344,8 @@ def apply_codec(
         sample_rate (int): Sample rate of the audio waveform.
         format (str): File format.
         channels_first (bool, optional):
-            When True, both the input and output Tensor have dimension ``[channel, time]``.
-            Otherwise, they have dimension ``[time, channel]``.
+            When True, both the input and output Tensor have dimension `(channel, time)`.
+            Otherwise, they have dimension `(time, channel)`.
         compression (float or None, optional): Used for formats other than WAV.
             For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
         encoding (str or None, optional): Changes the encoding for the supported formats.
@@ -1355,7 +1355,7 @@ def apply_codec(
     Returns:
         torch.Tensor: Resulting Tensor.
-            If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
+            If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
     """
     bytes = io.BytesIO()
     torchaudio.backend.sox_io_backend.save(bytes,
@@ -1453,7 +1453,7 @@ def compute_kaldi_pitch(
             This makes different types of features give the same number of frames. (default: True)

     Returns:
-        Tensor: Pitch feature. Shape: ``(batch, frames 2)`` where the last dimension
+        Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
             corresponds to pitch and NCCF.
     """
     shape = waveform.shape
@@ -1605,7 +1605,7 @@ def resample(
     more efficient computation if resampling multiple waveforms with the same resampling parameters.

     Args:
-        waveform (Tensor): The input signal of dimension (..., time)
+        waveform (Tensor): The input signal of dimension `(..., time)`
         orig_freq (float): The original frequency of the signal
         new_freq (float): The desired frequency
         lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
@@ -1617,7 +1617,7 @@ def resample(
         beta (float or None, optional): The shape parameter used for kaiser window.

     Returns:
-        Tensor: The waveform at the new frequency of dimension (..., time).
+        Tensor: The waveform at the new frequency of dimension `(..., time).`
     """
     assert orig_freq > 0.0 and new_freq > 0.0
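The ``time`` dimension of the output scales by the ratio ``new_freq / orig_freq``. The rounding in this sketch (round up) is an assumption for illustration; the library's exact boundary behavior should be checked against its own documentation:

```python
import math

def resampled_length(num_samples, orig_freq, new_freq):
    """Expected output length when resampling by new_freq / orig_freq.

    Assumption of this sketch: the length scales by the frequency ratio,
    rounded up. Actual library rounding may differ at the boundary.
    """
    return math.ceil(num_samples * new_freq / orig_freq)

print(resampled_length(16000, 16000, 8000))   # 8000
print(resampled_length(44100, 44100, 16000))  # 16000
```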
torchaudio/models/wav2vec2/components.py
@@ -301,9 +301,9 @@ class FeedForward(Module):
     def forward(self, x):
         """
         Args:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
         Returns:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
         """
         x = self.intermediate_dense(x)
         x = torch.nn.functional.gelu(x)
@@ -339,9 +339,9 @@ class EncoderLayer(Module):
     ):
         """
         Args:
-            x (Tensor): shape: ``(batch, sequence_length, embed_dim)``
+            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
             attention_mask (Tensor or None, optional):
-                shape: ``(batch, 1, sequence_length, sequence_length)``
+                shape: `(batch, 1, sequence_length, sequence_length)`
         """
         residual = x
torchaudio/models/wav2vec2/model.py
@@ -48,10 +48,10 @@ class Wav2Vec2Model(Module):
         transformer block in encoder.

         Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
             lengths (Tensor or None, optional):
                 Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.
             num_layers (int or None, optional):
                 If given, limit the number of intermediate layers to go through.
                 Providing `1` will stop the computation after going through one
@@ -62,9 +62,9 @@ class Wav2Vec2Model(Module):
             List of Tensors and an optional Tensor:
             List of Tensors
                 Features from requested layers.
-                Each Tensor is of shape: ``(batch, frames, feature dimention)``
+                Each Tensor is of shape: `(batch, frames, feature dimention)`
             Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                 is retuned. It indicates the valid length of each feature in the batch.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
@@ -79,18 +79,18 @@ class Wav2Vec2Model(Module):
         """Compute the sequence of probability distribution over labels.

         Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
             lengths (Tensor or None, optional):
                 Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.

         Returns:
             Tensor and an optional Tensor:
             Tensor
                 The sequences of probability distribution (in logit) over labels.
-                Shape: ``(batch, frames, num labels)``.
+                Shape: `(batch, frames, num labels)`.
             Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                 is retuned. It indicates the valid length of each feature in the batch.
         """
         x, lengths = self.feature_extractor(waveforms, lengths)
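The returned `(batch, )` lengths tensor tells the caller how many frames of each padded sequence are valid. A common way to consume it is to expand it into a boolean padding mask; ``lengths_to_mask`` below is a hypothetical pure-Python helper for illustration, not part of the torchaudio API:

```python
def lengths_to_mask(lengths, max_len=None):
    """Boolean mask of shape (batch, max_len): True where the frame is valid."""
    if max_len is None:
        max_len = max(lengths)
    return [[t < n for t in range(max_len)] for n in lengths]

# Batch of two sequences with 3 and 1 valid frames respectively.
print(lengths_to_mask([3, 1]))  # [[True, True, True], [True, False, False]]
```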
torchaudio/transforms.py
@@ -971,8 +971,8 @@ class TimeStretch(torch.nn.Module):
         r"""
         Args:
             complex_specgrams (Tensor):
-                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+                Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+                or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
             overriding_rate (float or None, optional): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
@@ -1018,10 +1018,10 @@ class Fade(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         waveform_length = waveform.size()[-1]
         device = waveform.device
@@ -1092,11 +1092,11 @@ class _AxisMasking(torch.nn.Module):
     def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
         r"""
         Args:
-            specgram (Tensor): Tensor of dimension (..., freq, time).
+            specgram (Tensor): Tensor of dimension `(..., freq, time)`.
             mask_value (float): Value to assign to the masked columns.

         Returns:
-            Tensor: Masked spectrogram of dimensions (..., freq, time).
+            Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
         """
         # if iid_masks flag marked and specgram has a batch dimension
         if self.iid_masks and specgram.dim() == 4:
@@ -1157,10 +1157,10 @@ class Vol(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         if self.gain_type == "amplitude":
             waveform = waveform * self.gain
@@ -1201,10 +1201,10 @@ class SlidingWindowCmn(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
         """
         cmn_waveform = F.sliding_window_cmn(
             waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
@@ -1374,10 +1374,10 @@ class SpectralCentroid(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
-            Tensor: Spectral Centroid of size (..., time).
+            Tensor: Spectral Centroid of size `(..., time)`.
         """
         return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window,
                                    self.n_fft, self.hop_length,
@@ -1428,7 +1428,7 @@ class PitchShift(torch.nn.Module):
     def forward(self, waveform: Tensor) -> Tensor:
         r"""
         Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.

         Returns:
             Tensor: The pitch-shifted audio of shape `(..., time)`.
@@ -1513,7 +1513,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
     r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

     Args:
-        input (torch.Tensor): Tensor of dimension (..., channel, channel)
+        input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
         dim1 (int, optional): the first dimension of the diagonal matrix
             (Default: -1)
         dim2 (int, optional): the second dimension of the diagonal matrix
@@ -1548,14 +1548,14 @@ class PSD(torch.nn.Module):
         """
         Args:
             specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
             mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
-                of dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
+                of dimension `(..., channel, freq, time)` if multi_mask is ``True``

         Returns:
             torch.Tensor: PSD matrix of the input STFT matrix.
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
         """
         # outer product:
         # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
@@ -1804,11 +1804,11 @@ class MVDR(torch.nn.Module):
         Args:
             psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`

         Returns:
             torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
         """
         w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
         _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
@@ -1826,14 +1826,14 @@ class MVDR(torch.nn.Module):
         Args:
             psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
             psd_n (torch.Tensor): covariance matrix of noise
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
             reference_vector (torch.Tensor): one-hot reference channel matrix

         Returns:
             torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
         """
         phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
         stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
@@ -1850,13 +1850,13 @@ class MVDR(torch.nn.Module):
         r"""Apply the beamforming weight to the noisy STFT
         Args:
             specgram (torch.tensor): multi-channel noisy STFT
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
             beamform_vector (torch.Tensor): beamforming weight matrix
-                Tensor of dimension (..., freq, channel)
+                Tensor of dimension `(..., freq, channel)`

         Returns:
             torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
         """
         # (..., channel) x (..., channel, freq, time) -> (..., freq, time)
         specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
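The einsum pattern ``"...fc,...cft->...ft"`` above sums the conjugated beamforming weights over the channel axis for every frequency-time cell. A pure-Python version of that contraction, written out for a single (unbatched) example to make the index roles explicit:

```python
def apply_beamforming(weights, specgram):
    """out[f][t] = sum_c conj(w[f][c]) * spec[c][f][t].

    weights: (freq, channel) complex; specgram: (channel, freq, time) complex.
    Pure-Python expansion of the einsum "...fc,...cft->...ft" pattern.
    """
    num_freq = len(weights)
    num_chan = len(specgram)
    num_time = len(specgram[0][0])
    return [
        [sum(weights[f][c].conjugate() * specgram[c][f][t] for c in range(num_chan))
         for t in range(num_time)]
        for f in range(num_freq)
    ]

w = [[1 + 0j, 1j]]               # 1 freq bin, 2 channels
spec = [[[2 + 0j]], [[1j]]]      # 2 channels, 1 freq bin, 1 time frame
print(apply_beamforming(w, spec))  # [[(3+0j)]]
```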
@@ -1897,18 +1897,18 @@ class MVDR(torch.nn.Module):
         Args:
             specgram (torch.Tensor): the multi-channel STF of the noisy speech.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
             mask_s (torch.Tensor): Time-Frequency mask of target speech.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
             mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
                 (Default: None)

         Returns:
             torch.Tensor: The single-channel STFT of the enhanced speech.
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
         """
         if specgram.ndim < 3:
             raise ValueError(