OpenDAS / Torchaudio · Commits

Commit 8f270d09, authored Oct 07, 2021 by Caroline Chen

Standardize tensor shapes format in docs (#1838)

Parent: dc0990c7
Showing 5 changed files with 80 additions and 80 deletions (+80, -80).
  torchaudio/functional/filtering.py        +10 -10
  torchaudio/functional/functional.py       +26 -26
  torchaudio/models/wav2vec2/components.py  +4  -4
  torchaudio/models/wav2vec2/model.py       +8  -8
  torchaudio/transforms.py                  +32 -32
torchaudio/functional/filtering.py
...
...
@@ -650,20 +650,20 @@ def filtfilt(
    Inspired by https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.filtfilt.html
    Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
        a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
            Lower delay coefficients are first, e.g. ``[a0, a1, a2, ...]``.
            Must be same size as b_coeffs (pad with 0's as necessary).
        b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
            Lower delay coefficients are first, e.g. ``[b0, b1, b2, ...]``.
            Must be same size as a_coeffs (pad with 0's as necessary).
        clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
    Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-            are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+            are 2D Tensors, or `(..., time)` otherwise.
    """
    forward_filtered = lfilter(waveform, a_coeffs, b_coeffs, clamp=False, batching=True)
    backward_filtered = lfilter(
...
...
@@ -970,13 +970,13 @@ def lfilter(
    Using double precision could also minimize numerical precision errors.
    Args:
-        waveform (Tensor): audio waveform of dimension of ``(..., time)``. Must be normalized to -1 to 1.
+        waveform (Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1.
        a_coeffs (Tensor): denominator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
            Lower delays coefficients are first, e.g. ``[a0, a1, a2, ...]``.
            Must be same size as b_coeffs (pad with 0's as necessary).
        b_coeffs (Tensor): numerator coefficients of difference equation of dimension of either
-            1D with shape ``(num_order + 1)`` or 2D with shape ``(num_filters, num_order + 1)``.
+            1D with shape `(num_order + 1)` or 2D with shape `(num_filters, num_order + 1)`.
            Lower delays coefficients are first, e.g. ``[b0, b1, b2, ...]``.
            Must be same size as a_coeffs (pad with 0's as necessary).
        clamp (bool, optional): If ``True``, clamp the output signal to be in the range [-1, 1] (Default: ``True``)
...
...
@@ -986,8 +986,8 @@ def lfilter(
            a_coeffs[i], b_coeffs[i], clamp=clamp, batching=False)``. (Default: ``True``)
    Returns:
-        Tensor: Waveform with dimension of either ``(..., num_filters, time)`` if ``a_coeffs`` and ``b_coeffs``
-            are 2D Tensors, or ``(..., time)`` otherwise.
+        Tensor: Waveform with dimension of either `(..., num_filters, time)` if ``a_coeffs`` and ``b_coeffs``
+            are 2D Tensors, or `(..., time)` otherwise.
    """
    assert a_coeffs.size() == b_coeffs.size()
    assert a_coeffs.ndim <= 2
...
...
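For context, a minimal usage sketch of the 1D-coefficient form documented above (not part of this commit; it assumes both `lfilter` and `filtfilt` are exported from `torchaudio.functional` in this release, and the coefficient values are arbitrary):

    import torch
    import torchaudio.functional as F

    waveform = torch.rand(2, 16000) * 2 - 1    # (..., time), normalized to [-1, 1]
    a_coeffs = torch.tensor([1.0, -0.95])      # denominator, shape (num_order + 1)
    b_coeffs = torch.tensor([0.05, 0.0])       # numerator, padded with 0 to match a_coeffs
    filtered = F.lfilter(waveform, a_coeffs, b_coeffs)     # -> (..., time)
    smoothed = F.filtfilt(waveform, a_coeffs, b_coeffs)    # zero-phase forward/backward pass, -> (..., time)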
torchaudio/functional/functional.py
...
...
@@ -62,7 +62,7 @@ def spectrogram(
    The spectrogram can be either magnitude-only or complex.
    Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
...
...
@@ -89,7 +89,7 @@ def spectrogram(
        power spectrogram, which is a real-valued tensor.
    Returns:
-        Tensor: Dimension (..., freq, time), freq is
+        Tensor: Dimension `(..., freq, time)`, freq is
            ``n_fft // 2 + 1`` and ``n_fft`` is the number of
            Fourier bins, and time is the number of window hops (n_frame).
    """
...
...
@@ -172,7 +172,7 @@ def inverse_spectrogram(
            Default: ``True``
    Returns:
-        Tensor: Dimension (..., time). Least squares estimation of the original signal.
+        Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
    """
    if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
...
...
@@ -246,7 +246,7 @@ def griffinlim(
    and *Signal estimation from modified short-time Fourier transform* [:footcite:`1172092`].
    Args:
-        specgram (Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames)
+        specgram (Tensor): A magnitude-only STFT spectrogram of dimension `(..., freq, frames)`
            where freq is ``n_fft // 2 + 1``.
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins
...
...
@@ -263,7 +263,7 @@ def griffinlim(
        rand_init (bool): Initializes phase randomly if True, to zero otherwise.
    Returns:
-        torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given.
+        torch.Tensor: waveform of `(..., time)`, where time equals the ``length`` parameter if given.
    """
    assert momentum < 1, 'momentum={} > 1 can be unstable'.format(momentum)
    assert momentum >= 0, 'momentum={} < 0'.format(momentum)
...
...
@@ -791,10 +791,10 @@ def phase_vocoder(
    Args:
        complex_specgrams (Tensor):
-            Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-            or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+            Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+            or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
        rate (float): Speed-up factor
-        phase_advance (Tensor): Expected phase advance in each bin. Dimension of (freq, 1)
+        phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
    Returns:
        Tensor:
...
...
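A small sketch of how the shapes documented in this hunk fit together (not part of this commit; the hop length and bin count are arbitrary):

    import math
    import torch
    import torchaudio.functional as F

    freq, hop_length = 1025, 512
    complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)  # (..., freq, num_frame), complex dtype
    rate = 1.3                                                         # speed up by 30%
    phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]   # (freq, 1)
    stretched = F.phase_vocoder(complex_specgrams, rate, phase_advance)
    # stretched has ceil(300 / 1.3) == 231 frames: shape (2, 1025, 231)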
@@ -907,13 +907,13 @@ def mask_along_axis_iid(
    ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
    Args:
-        specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
+        specgrams (Tensor): Real spectrograms `(batch, channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (2 -> frequency, 3 -> time)
    Returns:
-        Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
+        Tensor: Masked spectrograms of dimensions `(batch, channel, freq, time)`
    """
    if axis not in [2, 3]:
...
...
@@ -950,13 +950,13 @@ def mask_along_axis(
    All examples will have the same mask interval.
    Args:
-        specgram (Tensor): Real spectrogram (channel, freq, time)
+        specgram (Tensor): Real spectrogram `(channel, freq, time)`
        mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param]
        mask_value (float): Value to assign to the masked columns
        axis (int): Axis to apply masking on (1 -> frequency, 2 -> time)
    Returns:
-        Tensor: Masked spectrogram of dimensions (channel, freq, time)
+        Tensor: Masked spectrogram of dimensions `(channel, freq, time)`
    """
    if axis not in [1, 2]:
        raise ValueError('Only Frequency and Time masking are supported')
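A quick sketch of both masking variants with the shapes documented above (not part of this commit; the mask parameters are arbitrary):

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(2, 201, 400)        # (channel, freq, time)
    masked_time = F.mask_along_axis(specgram, mask_param=80, mask_value=0.0, axis=2)
    masked_freq = F.mask_along_axis(specgram, mask_param=27, mask_value=0.0, axis=1)

    batch = torch.randn(4, 2, 201, 400)        # (batch, channel, freq, time)
    masked_iid = F.mask_along_axis_iid(batch, mask_param=80, mask_value=0.0, axis=3)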
...
...
@@ -999,12 +999,12 @@ def compute_deltas(
    :math:`N` is ``(win_length-1)//2``.
    Args:
-        specgram (Tensor): Tensor of audio of dimension (..., freq, time)
+        specgram (Tensor): Tensor of audio of dimension `(..., freq, time)`
        win_length (int, optional): The window length used for computing delta (Default: ``5``)
        mode (str, optional): Mode parameter passed to padding (Default: ``"replicate"``)
    Returns:
-        Tensor: Tensor of deltas of dimension (..., freq, time)
+        Tensor: Tensor of deltas of dimension `(..., freq, time)`
    Example
        >>> specgram = torch.randn(1, 40, 1000)
...
...
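The Example block above is cut off by the diff context; as a self-contained sketch of the same call (my own, not the docstring's continuation):

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(1, 40, 1000)               # (..., freq, time)
    delta = F.compute_deltas(specgram, win_length=5)  # -> (..., freq, time)
    delta2 = F.compute_deltas(delta)                  # second-order deltas, same shape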
@@ -1172,7 +1172,7 @@ def detect_pitch_frequency(
    It is implemented using normalized cross-correlation function and median smoothing.
    Args:
-        waveform (Tensor): Tensor of audio of dimension (..., freq, time)
+        waveform (Tensor): Tensor of audio of dimension `(..., freq, time)`
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float, optional): Duration of a frame (Default: ``10 ** (-2)``).
        win_length (int, optional): The window length for median smoothing (in number of frames) (Default: ``30``).
...
...
@@ -1180,7 +1180,7 @@ def detect_pitch_frequency(
        freq_high (int, optional): Highest frequency that can be detected (Hz) (Default: ``3400``).
    Returns:
-        Tensor: Tensor of freq of dimension (..., frame)
+        Tensor: Tensor of freq of dimension `(..., frame)`
    """
    # pack batch
    shape = list(waveform.size())
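A minimal sketch of the call with the shapes documented above (not part of this commit; the test tone is arbitrary):

    import math
    import torch
    import torchaudio.functional as F

    sample_rate = 16000
    t = torch.arange(sample_rate) / sample_rate
    waveform = torch.sin(2 * math.pi * 220.0 * t).unsqueeze(0)   # (..., time), a 220 Hz tone
    pitch = F.detect_pitch_frequency(waveform, sample_rate)      # -> (..., frame), values near 220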
...
...
@@ -1211,7 +1211,7 @@ def sliding_window_cmn(
    Apply sliding-window cepstral mean (and optionally variance) normalization per utterance.
    Args:
-        specgram (Tensor): Tensor of audio of dimension (..., time, freq)
+        specgram (Tensor): Tensor of audio of dimension `(..., time, freq)`
        cmn_window (int, optional): Window in frames for running average CMN computation (int, default = 600)
        min_cmn_window (int, optional): Minimum CMN window used at start of decoding (adds latency only at start).
            Only applicable if center == false, ignored if center==true (int, default = 100)
...
...
@@ -1220,7 +1220,7 @@ def sliding_window_cmn(
        norm_vars (bool, optional): If true, normalize variance to one. (bool, default = false)
    Returns:
-        Tensor: Tensor matching input shape (..., freq, time)
+        Tensor: Tensor matching input shape `(..., freq, time)`
    """
    input_shape = specgram.shape
    num_frames, num_feats = input_shape[-2:]
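A hedged sketch of the functional form with the documented defaults (not part of this commit):

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(1, 1000, 40)        # (..., time, freq), as documented for the input
    normalized = F.sliding_window_cmn(
        specgram, cmn_window=600, min_cmn_window=100, center=False, norm_vars=False,
    )                                          # returns a tensor matching the input shape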
...
...
@@ -1307,7 +1307,7 @@ def spectral_centroid(
    frequency values, weighted by their magnitude.
    Args:
-        waveform (Tensor): Tensor of audio of dimension (..., time)
+        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
...
...
@@ -1316,7 +1316,7 @@ def spectral_centroid(
        win_length (int): Window size
    Returns:
-        Tensor: Dimension (..., time)
+        Tensor: Dimension `(..., time)`
    """
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)
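A short sketch of the functional call using the parameters listed above (not part of this commit; the window setup is illustrative):

    import torch
    import torchaudio.functional as F

    waveform = torch.randn(1, 16000)                       # (..., time)
    sample_rate, n_fft, hop_length, win_length = 16000, 400, 200, 400
    window = torch.hann_window(win_length)
    centroid = F.spectral_centroid(
        waveform, sample_rate, pad=0, window=window,
        n_fft=n_fft, hop_length=hop_length, win_length=win_length,
    )                                                      # -> (..., time), one value per frame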
...
...
@@ -1344,8 +1344,8 @@ def apply_codec(
        sample_rate (int): Sample rate of the audio waveform.
        format (str): File format.
        channels_first (bool, optional):
-            When True, both the input and output Tensor have dimension ``[channel, time]``.
-            Otherwise, they have dimension ``[time, channel]``.
+            When True, both the input and output Tensor have dimension `(channel, time)`.
+            Otherwise, they have dimension `(time, channel)`.
        compression (float or None, optional): Used for formats other than WAV.
            For more details see :py:func:`torchaudio.backend.sox_io_backend.save`.
        encoding (str or None, optional): Changes the encoding for the supported formats.
...
...
@@ -1355,7 +1355,7 @@ def apply_codec(
    Returns:
        torch.Tensor: Resulting Tensor.
-        If ``channels_first=True``, it has ``[channel, time]`` else ``[time, channel]``.
+        If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
    """
    bytes = io.BytesIO()
    torchaudio.backend.sox_io_backend.save(bytes,
...
...
@@ -1453,7 +1453,7 @@ def compute_kaldi_pitch(
            This makes different types of features give the same number of frames. (default: True)
    Returns:
-        Tensor: Pitch feature. Shape: ``(batch, frames 2)`` where the last dimension
+        Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension
            corresponds to pitch and NCCF.
    """
    shape = waveform.shape
...
...
@@ -1605,7 +1605,7 @@ def resample(
    more efficient computation if resampling multiple waveforms with the same resampling parameters.
    Args:
-        waveform (Tensor): The input signal of dimension (..., time)
+        waveform (Tensor): The input signal of dimension `(..., time)`
        orig_freq (float): The original frequency of the signal
        new_freq (float): The desired frequency
        lowpass_filter_width (int, optional): Controls the sharpness of the filter, more == sharper
...
...
@@ -1617,7 +1617,7 @@ def resample(
        beta (float or None, optional): The shape parameter used for kaiser window.
    Returns:
-        Tensor: The waveform at the new frequency of dimension (..., time).
+        Tensor: The waveform at the new frequency of dimension `(..., time).`
    """
    assert orig_freq > 0.0 and new_freq > 0.0
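A one-line sketch of the functional resampler with the documented shape (not part of this commit):

    import torch
    import torchaudio.functional as F

    waveform = torch.randn(2, 44100)                                    # (..., time) at 44.1 kHz
    resampled = F.resample(waveform, orig_freq=44100, new_freq=16000)   # -> (..., time) at 16 kHz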
...
...
torchaudio/models/wav2vec2/components.py
...
...
@@ -301,9 +301,9 @@ class FeedForward(Module):
    def forward(self, x):
        """
        Args:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
        Returns:
-            x (Tensor): shape: ``(batch, sequence_length, io_features)``
+            x (Tensor): shape: `(batch, sequence_length, io_features)`
        """
        x = self.intermediate_dense(x)
        x = torch.nn.functional.gelu(x)
...
...
@@ -339,9 +339,9 @@ class EncoderLayer(Module):
    ):
        """
        Args:
-            x (Tensor): shape: ``(batch, sequence_length, embed_dim)``
+            x (Tensor): shape: `(batch, sequence_length, embed_dim)`
            attention_mask (Tensor or None, optional):
-                shape: ``(batch, 1, sequence_length, sequence_length)``
+                shape: `(batch, 1, sequence_length, sequence_length)`
        """
        residual = x
...
...
torchaudio/models/wav2vec2/model.py
...
...
@@ -48,10 +48,10 @@ class Wav2Vec2Model(Module):
            transformer block in encoder.
        Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.
            num_layers (int or None, optional):
                If given, limit the number of intermediate layers to go through.
                Providing `1` will stop the computation after going through one
...
...
@@ -62,9 +62,9 @@ class Wav2Vec2Model(Module):
            List of Tensors and an optional Tensor:
            List of Tensors
                Features from requested layers.
-                Each Tensor is of shape: ``(batch, frames, feature dimention)``
+                Each Tensor is of shape: `(batch, frames, feature dimention)`
            Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is retuned. It indicates the valid length of each feature in the batch.
        """
        x, lengths = self.feature_extractor(waveforms, lengths)
...
...
@@ -79,18 +79,18 @@ class Wav2Vec2Model(Module):
        """Compute the sequence of probability distribution over labels.
        Args:
-            waveforms (Tensor): Audio tensor of shape ``(batch, frames)``.
+            waveforms (Tensor): Audio tensor of shape `(batch, frames)`.
            lengths (Tensor or None, optional):
                Indicates the valid length of each audio sample in the batch.
-                Shape: ``(batch, )``.
+                Shape: `(batch, )`.
        Returns:
            Tensor and an optional Tensor:
            Tensor
                The sequences of probability distribution (in logit) over labels.
-                Shape: ``(batch, frames, num labels)``.
+                Shape: `(batch, frames, num labels)`.
            Tensor or None
-                If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
+                If ``lengths`` argument was provided, a Tensor of shape `(batch, )`
                is retuned. It indicates the valid length of each feature in the batch.
        """
        x, lengths = self.feature_extractor(waveforms, lengths)
...
...
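To make the documented shapes concrete, a hedged sketch of driving a Wav2Vec2Model; the `wav2vec2_base` factory and its `aux_num_out` argument are assumptions about this torchaudio version and are not part of the commit:

    import torch
    import torchaudio

    model = torchaudio.models.wav2vec2_base(aux_num_out=32)   # hypothetical factory call
    waveforms = torch.randn(2, 16000)                         # (batch, frames)
    lengths = torch.tensor([16000, 12000])                    # (batch, )

    features, feat_lengths = model.extract_features(waveforms, lengths)
    # features: list of Tensors, each (batch, frames, feature dimension)
    logits, out_lengths = model(waveforms, lengths)           # logits: (batch, frames, num labels)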
torchaudio/transforms.py
...
...
@@ -971,8 +971,8 @@ class TimeStretch(torch.nn.Module):
        r"""
        Args:
            complex_specgrams (Tensor):
-                Either a real tensor of dimension of ``(..., freq, num_frame, complex=2)``
-                or a tensor of dimension ``(..., freq, num_frame)`` with complex dtype.
+                Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
+                or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
            overriding_rate (float or None, optional): speed up to apply to this batch.
                If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
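A sketch of feeding TimeStretch a complex spectrogram of the documented shape (not part of this commit; it assumes `Spectrogram(power=None)` yields a complex-valued output in this release):

    import torch
    import torchaudio.transforms as T

    n_fft, hop_length = 400, 200
    to_spec = T.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=None)   # complex output (assumed)
    stretch = T.TimeStretch(hop_length=hop_length, n_freq=n_fft // 2 + 1, fixed_rate=1.2)

    waveform = torch.randn(1, 16000)
    complex_specgram = to_spec(waveform)       # (..., freq, num_frame) with complex dtype
    stretched = stretch(complex_specgram)      # roughly 1 / 1.2 as many frames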
...
...
@@ -1018,10 +1018,10 @@ class Fade(torch.nn.Module):
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.
        Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
        """
        waveform_length = waveform.size()[-1]
        device = waveform.device
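A minimal sketch of the transform with the documented shape (not part of this commit; the fade lengths are arbitrary):

    import torch
    import torchaudio.transforms as T

    waveform = torch.randn(1, 16000)                                        # (..., time)
    fade = T.Fade(fade_in_len=1600, fade_out_len=1600, fade_shape="linear")
    faded = fade(waveform)                                                  # (..., time), same shape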
...
...
@@ -1092,11 +1092,11 @@ class _AxisMasking(torch.nn.Module):
    def forward(self, specgram: Tensor, mask_value: float = 0.) -> Tensor:
        r"""
        Args:
-            specgram (Tensor): Tensor of dimension (..., freq, time).
+            specgram (Tensor): Tensor of dimension `(..., freq, time)`.
            mask_value (float): Value to assign to the masked columns.
        Returns:
-            Tensor: Masked spectrogram of dimensions (..., freq, time).
+            Tensor: Masked spectrogram of dimensions `(..., freq, time)`.
        """
        # if iid_masks flag marked and specgram has a batch dimension
        if self.iid_masks and specgram.dim() == 4:
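The public subclasses of `_AxisMasking` are the usual entry points; a short sketch with the documented shapes (not part of this commit; the mask parameters are arbitrary):

    import torch
    import torchaudio.transforms as T

    specgram = torch.randn(1, 201, 400)                    # (..., freq, time)
    freq_mask = T.FrequencyMasking(freq_mask_param=30)
    time_mask = T.TimeMasking(time_mask_param=80)
    masked = time_mask(freq_mask(specgram))                # (..., freq, time), same shape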
...
...
@@ -1157,10 +1157,10 @@ class Vol(torch.nn.Module):
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.
        Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
        """
        if self.gain_type == "amplitude":
            waveform = waveform * self.gain
...
...
@@ -1201,10 +1201,10 @@ class SlidingWindowCmn(torch.nn.Module):
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.
        Returns:
-            Tensor: Tensor of audio of dimension (..., time).
+            Tensor: Tensor of audio of dimension `(..., time)`.
        """
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
...
...
@@ -1374,10 +1374,10 @@ class SpectralCentroid(torch.nn.Module):
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.
        Returns:
-            Tensor: Spectral Centroid of size (..., time).
+            Tensor: Spectral Centroid of size `(..., time)`.
        """
        return F.spectral_centroid(waveform, self.sample_rate, self.pad, self.window, self.n_fft, self.hop_length,
...
...
@@ -1428,7 +1428,7 @@ class PitchShift(torch.nn.Module):
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
-            waveform (Tensor): Tensor of audio of dimension (..., time).
+            waveform (Tensor): Tensor of audio of dimension `(..., time)`.
        Returns:
            Tensor: The pitch-shifted audio of shape `(..., time)`.
...
...
@@ -1513,7 +1513,7 @@ def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
    Args:
-        input (torch.Tensor): Tensor of dimension (..., channel, channel)
+        input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
        dim1 (int, optional): the first dimension of the diagonal matrix
            (Default: -1)
        dim2 (int, optional): the second dimension of the diagonal matrix
...
...
@@ -1548,14 +1548,14 @@ class PSD(torch.nn.Module):
        """
        Args:
            specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
-                of dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
+                of dimension `(..., channel, freq, time)` if multi_mask is ``True``
        Returns:
            torch.Tensor: PSD matrix of the input STFT matrix.
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
        """
        # outer product:
        # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
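One way to read the documented shapes: a hand-rolled sketch of a time-averaged outer product (an illustration only, not the transform's actual implementation):

    import torch

    specgram = torch.randn(6, 201, 100, dtype=torch.cfloat)          # (channel, freq, time)
    psd = torch.einsum("cft,eft->fce", specgram, specgram.conj())    # (freq, channel, channel)
    psd = psd / specgram.size(-1)                                    # average over time frames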
...
...
@@ -1804,11 +1804,11 @@ class MVDR(torch.nn.Module):
        Args:
            psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
        """
        w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
        _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
...
...
@@ -1826,14 +1826,14 @@ class MVDR(torch.nn.Module):
        Args:
            psd_s (torch.tensor): covariance matrix of speech
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
            psd_n (torch.Tensor): covariance matrix of noise
-                Tensor of dimension (..., freq, channel, channel)
+                Tensor of dimension `(..., freq, channel, channel)`
            reference_vector (torch.Tensor): one-hot reference channel matrix
        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, channel, 1)
+                Tensor of dimension `(..., freq, channel, 1)`
        """
        phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
        stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
...
...
@@ -1850,13 +1850,13 @@ class MVDR(torch.nn.Module):
        r"""Apply the beamforming weight to the noisy STFT
        Args:
            specgram (torch.tensor): multi-channel noisy STFT
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            beamform_vector (torch.Tensor): beamforming weight matrix
-                Tensor of dimension (..., freq, channel)
+                Tensor of dimension `(..., freq, channel)`
        Returns:
            torch.Tensor: the enhanced STFT
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
        """
        # (..., channel) x (..., channel, freq, time) -> (..., freq, time)
        specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
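A self-contained restatement of the einsum in this hunk, with the documented shapes spelled out on random data (for illustration only):

    import torch

    channel, freq, time = 6, 201, 100
    specgram = torch.randn(channel, freq, time, dtype=torch.cfloat)   # (..., channel, freq, time)
    beamform_vector = torch.randn(freq, channel, dtype=torch.cfloat)  # (..., freq, channel)
    enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
    # enhanced: (..., freq, time), the single-channel enhanced STFT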
...
...
@@ -1897,18 +1897,18 @@ class MVDR(torch.nn.Module):
        Args:
            specgram (torch.Tensor): the multi-channel STF of the noisy speech.
-                Tensor of dimension (..., channel, freq, time)
+                Tensor of dimension `(..., channel, freq, time)`
            mask_s (torch.Tensor): Time-Frequency mask of target speech.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
-                Tensor of dimension (..., freq, time) if multi_mask is ``False``
-                or or dimension (..., channel, freq, time) if multi_mask is ``True``
+                Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
+                or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
                (Default: None)
        Returns:
            torch.Tensor: The single-channel STFT of the enhanced speech.
-                Tensor of dimension (..., freq, time)
+                Tensor of dimension `(..., freq, time)`
        """
        if specgram.ndim < 3:
            raise ValueError(
...
...