Commit 4b021ae3 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add RTFMVDR module (#2368)

Summary:
Add a new design of MVDR module.
The RTFMVDR module supports the method based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
The input arguments are:
- multi-channel spectrum.
- RTF vector of the target speech
- PSD matrix of noise.
- reference channel in the microphone array.
- diagonal_loading option to enable or disable diagonal loading in matrix inverse computation.
- diag_eps for computing the inverse of the matrix.
- eps for computing the beamforming weight.
The output of the module is the single-channel complex-valued spectrum for the enhanced speech.

Pull Request resolved: https://github.com/pytorch/audio/pull/2368

Reviewed By: carolineechen

Differential Revision: D36214940

Pulled By: nateanl

fbshipit-source-id: 5f29f778663c96591e1b520b15f7876d07116937
parent da1e83cc
......@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward
:hidden:`RTFMVDR`
-----------------
.. autoclass:: RTFMVDR
.. automethod:: forward
:hidden:`SoudenMVDR`
--------------------
......
......@@ -303,6 +303,16 @@ class AutogradTestMixin(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:])
self.assert_grad(transform, [spectrogram, mask_s, mask_n])
def test_rtf_mvdr(self):
transform = T.RTFMVDR()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
specgram = get_spectrogram(waveform, n_fft=400)
channel, freq, _ = specgram.shape
rtf = torch.rand(freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
self.assert_grad(transform, [specgram, rtf, psd_n, reference_channel])
def test_souden_mvdr(self):
transform = T.SoudenMVDR()
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
......
......@@ -220,6 +220,25 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual(computed, expected)
def test_rtf_mvdr(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
specgram = specgram.reshape(batch_size, channel, freq, time)
rtf = torch.rand(batch_size, freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
reference_channel = 0
transform = T.RTFMVDR()
# Single then transform then batch
expected = [transform(specgram[i], rtf[i], psd_n[i], reference_channel) for i in range(batch_size)]
expected = torch.stack(expected)
# Batch then transform
computed = transform(specgram, rtf, psd_n, reference_channel)
self.assertEqual(computed, expected)
def test_souden_mvdr(self):
waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
specgram = common_utils.get_spectrogram(waveform, n_fft=400)
......
......@@ -175,6 +175,15 @@ class Transforms(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
def test_rtf_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
channel, freq, _ = specgram.shape
rtf = torch.rand(freq, channel, dtype=self.complex_dtype, device=self.device)
psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
reference_channel = 0
self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)
def test_souden_mvdr(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......
......@@ -1892,35 +1892,39 @@ def mvdr_weights_rtf(
.. properties:: Autograd TorchScript
Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight martrix
:math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
{{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
where :math:`\bm{v}` is the RTF or the steering vector.
:math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
Args:
rtf (Tensor): The complex-valued RTF vector of target speech.
Tensor of dimension `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor, optional): Indicate the reference channel.
If the dtype is ``int``, it represent the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
rtf (torch.Tensor): The complex-valued RTF vector of target speech.
Tensor with dimensions `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
If a non-None value is given, the MVDR weights will be normalized by ``rtf[..., reference_channel].conj()``
(Default: ``None``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
(Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
psd_n = _tik_reg(psd_n, reg=diag_eps)
# numerator = psd_n.inv() @ stv
numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
......
......@@ -24,6 +24,7 @@ from ._transforms import (
RNNTLoss,
PSD,
MVDR,
RTFMVDR,
SoudenMVDR,
)
......@@ -46,6 +47,7 @@ __all__ = [
"PSD",
"PitchShift",
"RNNTLoss",
"RTFMVDR",
"Resample",
"SlidingWindowCmn",
"SoudenMVDR",
......
......@@ -2091,6 +2091,70 @@ class MVDR(torch.nn.Module):
return specgram_enhanced
class RTFMVDR(torch.nn.Module):
r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
.. devices:: CPU CUDA
.. properties:: Autograd TorchScript
Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
.. math::
\hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
:math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
The beamforming weight is computed by:
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
{{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
"""
def forward(
self,
specgram: Tensor,
rtf: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
"""
Args:
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`
rtf (torch.Tensor): The complex-valued RTF vector of target speech.
Tensor with dimensions `(..., freq, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
"""
w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
return spectrum_enhanced
class SoudenMVDR(torch.nn.Module):
r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment