Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
ac97ad82
"vscode:/vscode.git/clone" did not exist on "cf48603943ba2678349ba0dadd06a50099ae5eec"
Unverified
Commit
ac97ad82
authored
Sep 20, 2021
by
nateanl
Committed by
GitHub
Sep 20, 2021
Browse files
Move MVDR and PSD modules to transforms (#1771)
parent
88ca1e05
Changes
21
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
441 additions
and
0 deletions
+441
-0
torchaudio/transforms.py
torchaudio/transforms.py
+441
-0
No files found.
torchaudio/transforms.py
View file @
ac97ad82
...
@@ -38,6 +38,8 @@ __all__ = [
...
@@ -38,6 +38,8 @@ __all__ = [
'ComputeDeltas'
,
'ComputeDeltas'
,
'PitchShift'
,
'PitchShift'
,
'RNNTLoss'
,
'RNNTLoss'
,
'PSD'
,
'MVDR'
,
]
]
...
@@ -1488,3 +1490,442 @@ class RNNTLoss(torch.nn.Module):
...
@@ -1488,3 +1490,442 @@ class RNNTLoss(torch.nn.Module):
self
.
clamp
,
self
.
clamp
,
self
.
reduction
self
.
reduction
)
)
def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
        input (torch.Tensor): Tensor of dimension (..., channel, channel)
        dim1 (int, optional): the first dimension of the diagonal matrix
            (Default: -1)
        dim2 (int, optional): the second dimension of the diagonal matrix
            (Default: -2)

    Returns:
        torch.Tensor: trace of the input Tensor

    Raises:
        ValueError: if ``input`` has fewer than two dimensions, or if the sizes of
            ``dim1`` and ``dim2`` differ.
    """
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if input.ndim < 2:
        raise ValueError("The dimension of the tensor must be at least 2.")
    if input.shape[dim1] != input.shape[dim2]:
        raise ValueError("The size of ``dim1`` and ``dim2`` must be the same.")
    # ``torch.diagonal`` appends the extracted diagonal as the last dimension,
    # so summing over ``-1`` yields the trace for every leading index.
    diagonal = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
    return diagonal.sum(dim=-1)
class PSD(torch.nn.Module):
    r"""Module that computes the cross-channel power spectral density (PSD) matrix.

    Args:
        multi_mask (bool, optional): if ``True``, the Time-Frequency mask carries a channel
            dimension, which is averaged out before the mask is applied. (Default: ``False``)
        normalize (bool, optional): if ``True``, the mask is divided by its sum over the
            time dimension before being applied. (Default: ``True``)
        eps (float, optional): value added to the denominator of the mask normalization.
            (Default: 1e-15)
    """

    def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
        super().__init__()
        self.multi_mask = multi_mask
        self.normalize = normalize
        self.eps = eps

    def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
                Tensor of dimension (..., channel, freq, time)
            mask (torch.Tensor or None, optional): Time-Frequency mask used to weight each frame.
                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
                of dimension (..., channel, freq, time) if multi_mask is ``True``

        Returns:
            torch.Tensor: PSD matrix of the input STFT matrix.
                Tensor of dimension (..., freq, channel, channel)
        """
        # Outer product between channels for every time-frequency bin:
        # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., freq, time, ch_1, ch_2)
        outer = torch.einsum("...cft,...eft->...ftce", [specgram, specgram.conj()])

        if mask is not None:
            weights = mask
            if self.multi_mask:
                # Collapse the channel dimension of the mask -> (..., freq, time)
                weights = weights.mean(dim=-3)
            if self.normalize:
                # Scale the mask by its (eps-stabilised) sum over time
                total = weights.sum(dim=-1, keepdim=True) + self.eps
                weights = weights / total
            # Broadcast the mask over the two trailing channel dimensions
            outer = outer * weights[..., None, None]

        # Accumulate over the time dimension -> (..., freq, channel, channel)
        return outer.sum(dim=-3)
class MVDR(torch.nn.Module):
    """MVDR module that performs MVDR beamforming with Time-Frequency masks.

    Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py

    We provide three solutions of MVDR beamforming. One is based on *reference channel selection*
    [:footcite:`souden2009optimal`].

    The other two solutions are based on the steering vector. We apply either *eigenvalue decomposition*
    [:footcite:`higuchi2016robust`] or the *power method* [:footcite:`mises1929praktische`] to get the
    steering vector from the PSD matrices.

    For online streaming audio, we provide a *recursive method* [:footcite:`higuchi2017online`] to update the
    PSD matrices of speech and noise, respectively.

    Args:
        ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
        solution (str, optional): the solution to get MVDR weight.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``)
        diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise.
            (Default: ``True``)
        diag_eps (float, optional): the coefficient multiplied to the identity matrix for diagonal loading.
            (Default: 1e-7)
        online (bool, optional): whether to update the mvdr vector based on the previous psd matrices.
            (Default: ``False``)

    Raises:
        ValueError: if ``solution`` is not one of the supported options.

    Note:
        The MVDR Module requires the input STFT to be double precision (``torch.complex128`` or ``torch.cdouble``),
        to improve the numerical stability. You can downgrade the precision to ``torch.float`` after generating the
        enhanced waveform for ASR joint training.

    Note:
        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
    """

    def __init__(
        self,
        ref_channel: int = 0,
        solution: str = "ref_channel",
        multi_mask: bool = False,
        diag_loading: bool = True,
        diag_eps: float = 1e-7,
        online: bool = False,
    ):
        super().__init__()
        # Explicit raise instead of ``assert``: under ``python -O`` an assert is stripped and
        # an unknown solution would silently fall through to the power-method branch.
        if solution not in ("ref_channel", "stv_evd", "stv_power"):
            raise ValueError(
                "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
            )
        self.ref_channel = ref_channel
        self.solution = solution
        self.multi_mask = multi_mask
        self.diag_loading = diag_loading
        self.diag_eps = diag_eps
        self.online = online
        self.psd = PSD(multi_mask)

        # Running state for the online (recursive) mode. The 1-D placeholders double as a
        # "not yet initialised" marker: ``_get_updated_mvdr_vector`` checks ``psd_s.ndim == 1``.
        psd_s: torch.Tensor = torch.zeros(1)
        psd_n: torch.Tensor = torch.zeros(1)
        mask_sum_s: torch.Tensor = torch.zeros(1)
        mask_sum_n: torch.Tensor = torch.zeros(1)
        self.register_buffer('psd_s', psd_s)
        self.register_buffer('psd_n', psd_n)
        self.register_buffer('mask_sum_s', mask_sum_s)
        self.register_buffer('mask_sum_n', mask_sum_n)

    def _get_updated_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        mask_s: torch.Tensor,
        mask_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Recursively update the MVDR beamforming vector.

        On the first call the given PSD matrices and mask sums are stored as the running
        state; on later calls they are merged with the stored state before the beamforming
        vector is computed.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            mask_s (torch.Tensor): T-F mask of target speech
            mask_n (torch.Tensor): T-F mask of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to denominators for numerical stability. (Default: 1e-8)

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if self.multi_mask:
            # Averaging mask along channel dimension
            mask_s = mask_s.mean(dim=-3)  # (..., freq, time)
            mask_n = mask_n.mean(dim=-3)  # (..., freq, time)
        if self.psd_s.ndim == 1:
            # First chunk: the buffers still hold their 1-D placeholders, so just store the state.
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = mask_s.sum(dim=-1)
            self.mask_sum_n = mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
        else:
            # Later chunks: blend the new PSD matrices with the stored ones, then
            # accumulate the mask sums used as blending weights on the next call.
            psd_s = self._get_updated_psd_speech(psd_s, mask_s)
            psd_n = self._get_updated_psd_noise(psd_n, mask_n)
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
            self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)

    def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
        r"""Update psd of speech recursively.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            mask_s (torch.Tensor): T-F mask of target speech

        Returns:
            torch.Tensor: the updated psd of speech
        """
        # NOTE: despite their names, ``numerator`` and ``denominator`` are the two blending
        # weights applied to the stored PSD and to the new PSD, respectively.
        numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
        denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
        psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
        return psd_s

    def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        r"""Update psd of noise recursively.

        Args:
            psd_n (torch.Tensor): psd matrix of target noise
            mask_n (torch.Tensor): T-F mask of target noise

        Returns:
            torch.Tensor: the updated psd of noise
        """
        # NOTE: despite their names, ``numerator`` and ``denominator`` are the two blending
        # weights applied to the stored PSD and to the new PSD, respectively.
        numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
        denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
        psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
        return psd_n

    def _get_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Compute the MVDR beamforming vector with the selected solution.

        ``ref_channel`` uses the reference channel selection method; ``stv_evd`` and any other
        value use a steering vector estimated by eigenvalue decomposition or by the power
        method, respectively.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to denominators for numerical stability. (Default: 1e-8)

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if diagonal_loading:
            psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps)
        if solution == "ref_channel":
            numerator = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
            # ws: (..., C, C) / (...,) -> (..., C, C)
            ws = numerator / (_get_mat_trace(numerator)[..., None, None] + eps)
            # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
            beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
        else:
            if solution == "stv_evd":
                stv = self._get_steering_vector_evd(psd_s)
            else:
                stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector)
            # numerator = psd_n.inv() @ stv
            numerator = torch.linalg.solve(psd_n, stv).squeeze(-1)  # (..., freq, channel)
            # denominator = stv^H @ psd_n.inv() @ stv
            denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
            # normalize the numerator by the conjugate of the reference-channel entry of stv
            scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
            beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)

        return beamform_vector

    def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
        r"""Estimate the steering vector by eigenvalue decomposition.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)

        Returns:
            torch.Tensor: the steering vector, i.e. the eigenvector belonging to the
                eigenvalue of largest magnitude.
                Tensor of dimension (..., freq, channel, 1)
        """
        w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
        # Index of the eigenvalue with the largest magnitude, per matrix
        _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
        indices = indices.unsqueeze(-1)
        # Pick the matching eigenvector (a column of ``v``)
        stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,)))  # (..., freq, channel, 1)
        return stv

    def _get_steering_vector_power(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        reference_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Estimate the steering vector by the power method.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)
            psd_n (torch.Tensor): covariance matrix of noise
                Tensor of dimension (..., freq, channel, channel)
            reference_vector (torch.Tensor): one-hot reference channel matrix

        Returns:
            torch.Tensor: the steering vector
                Tensor of dimension (..., freq, channel, 1)
        """
        phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
        # Start from the reference-channel column of phi ...
        stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
        stv = stv.unsqueeze(-1)
        # ... apply phi once more, then map back through psd_s.
        stv = torch.matmul(phi, stv)
        stv = torch.matmul(psd_s, stv)
        return stv

    def _apply_beamforming_vector(
        self,
        specgram: torch.Tensor,
        beamform_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Apply the beamforming weight to the noisy STFT

        Args:
            specgram (torch.tensor): multi-channel noisy STFT
                Tensor of dimension (..., channel, freq, time)
            beamform_vector (torch.Tensor): beamforming weight matrix
                Tensor of dimension (..., freq, channel)

        Returns:
            torch.Tensor: the enhanced STFT
                Tensor of dimension (..., freq, time)
        """
        # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
        specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
        return specgram_enhanced

    def _tik_reg(
        self,
        mat: torch.Tensor,
        reg: float = 1e-7,
        eps: float = 1e-8
    ) -> torch.Tensor:
        """Perform Tikhonov regularization (only modifying real part).

        Args:
            mat (torch.Tensor): input matrix (..., channel, channel)
            reg (float, optional): regularization factor (Default: 1e-7)
            eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)

        Returns:
            torch.Tensor: regularized matrix (..., channel, channel)
        """
        # Add eps
        C = mat.size(-1)
        eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
        # The loading amount is computed without tracking gradients
        with torch.no_grad():
            epsilon = _get_mat_trace(mat).real[..., None, None] * reg
            # in case that correlation_matrix is all-zero
            epsilon = epsilon + eps
        mat = mat + epsilon * eye[..., :, :]
        return mat

    def forward(
        self,
        specgram: torch.Tensor,
        mask_s: torch.Tensor,
        mask_n: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Perform MVDR beamforming.

        Args:
            specgram (torch.Tensor): the multi-channel STFT of the noisy speech.
                Tensor of dimension (..., channel, freq, time)
            mask_s (torch.Tensor): Time-Frequency mask of target speech.
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or of dimension (..., channel, freq, time) if multi_mask is ``True``
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or of dimension (..., channel, freq, time) if multi_mask is ``True``
                (Default: None)

        Returns:
            torch.Tensor: The single-channel STFT of the enhanced speech.
                Tensor of dimension (..., freq, time)

        Raises:
            ValueError: if ``specgram`` has fewer than three dimensions or is not ``torch.cdouble``.
        """
        if specgram.ndim < 3:
            raise ValueError(
                f"Expected at least 3D tensor (..., channel, freq, time). Found: {specgram.shape}"
            )
        if specgram.dtype != torch.cdouble:
            raise ValueError(
                f"The type of ``specgram`` tensor must be ``torch.cdouble``. Found: {specgram.dtype}"
            )
        if mask_n is None:
            warnings.warn(
                "``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``."
            )
            mask_n = 1 - mask_s

        shape = specgram.size()

        # pack batch
        specgram = specgram.reshape(-1, shape[-3], shape[-2], shape[-1])
        if self.multi_mask:
            mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1])
        else:
            mask_s = mask_s.reshape(-1, shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-2], shape[-1])

        psd_s = self.psd(specgram, mask_s)  # (..., freq, channel, channel)
        psd_n = self.psd(specgram, mask_n)  # (..., freq, channel, channel)

        # One-hot vector selecting the reference channel
        u = torch.zeros(
            specgram.size()[:-2],
            device=specgram.device,
            dtype=torch.cdouble
        )  # (..., channel)
        u[..., self.ref_channel].fill_(1)

        if self.online:
            w_mvdr = self._get_updated_mvdr_vector(
                psd_s,
                psd_n,
                mask_s,
                mask_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )
        else:
            w_mvdr = self._get_mvdr_vector(
                psd_s,
                psd_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )

        specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr)

        # unpack batch
        specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
        return specgram_enhanced
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment