"vscode:/vscode.git/clone" did not exist on "1a47a44dc6f577fecb80303182d4d855b5cb674b"
Commit 4b021ae3 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add RTFMVDR module (#2368)

Summary:
Add a new design of the MVDR module.
The RTFMVDR module implements the MVDR method based on the relative transfer function (RTF) of the target speech and the power spectral density (PSD) matrix of noise.
The input arguments are:
- multi-channel spectrum.
- RTF vector of the target speech.
- PSD matrix of noise.
- reference channel in the microphone array.
- diagonal_loading option to enable or disable diagonal loading in the matrix inverse computation.
- diag_eps for computing the inverse of the matrix.
- eps for computing the beamforming weight.
The output of the module is the single-channel complex-valued spectrum of the enhanced speech.
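
For illustration, a minimal usage sketch of the new transform (shapes follow the argument description above; the random tensors are only placeholders standing in for an estimated RTF and noise PSD):

```python
import torch
import torchaudio.transforms as T

channel, freq, time = 4, 201, 100
# Multi-channel complex-valued spectrum, e.g. from a complex Spectrogram/STFT
specgram = torch.rand(channel, freq, time, dtype=torch.cfloat)
# RTF vector of the target speech and PSD matrix of noise (placeholders here;
# in practice these are estimated from the data)
rtf = torch.rand(freq, channel, dtype=torch.cfloat)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)

transform = T.RTFMVDR()
enhanced = transform(specgram, rtf, psd_n, reference_channel=0)
print(enhanced.shape)  # torch.Size([201, 100]) -- single-channel complex spectrum
```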

Pull Request resolved: https://github.com/pytorch/audio/pull/2368

Reviewed By: carolineechen

Differential Revision: D36214940

Pulled By: nateanl

fbshipit-source-id: 5f29f778663c96591e1b520b15f7876d07116937
parent da1e83cc
@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas
   .. automethod:: forward

+:hidden:`RTFMVDR`
+-----------------
+
+.. autoclass:: RTFMVDR
+
+  .. automethod:: forward
+
 :hidden:`SoudenMVDR`
 --------------------
......
@@ -303,6 +303,16 @@ class AutogradTestMixin(TestBaseMixin):
         mask_n = torch.rand(spectrogram.shape[-2:])
         self.assert_grad(transform, [spectrogram, mask_s, mask_n])

+    def test_rtf_mvdr(self):
+        transform = T.RTFMVDR()
+        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
+        specgram = get_spectrogram(waveform, n_fft=400)
+        channel, freq, _ = specgram.shape
+        rtf = torch.rand(freq, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        self.assert_grad(transform, [specgram, rtf, psd_n, reference_channel])
+
     def test_souden_mvdr(self):
         transform = T.SoudenMVDR()
         waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
......
@@ -220,6 +220,25 @@ class TestTransforms(common_utils.TorchaudioTestCase):
         self.assertEqual(computed, expected)

+    def test_rtf_mvdr(self):
+        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
+        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
+        batch_size, channel, freq, time = 3, 2, specgram.shape[-2], specgram.shape[-1]
+        specgram = specgram.reshape(batch_size, channel, freq, time)
+        rtf = torch.rand(batch_size, freq, channel, dtype=torch.cfloat)
+        psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
+        reference_channel = 0
+        transform = T.RTFMVDR()
+
+        # Single then transform then batch
+        expected = [transform(specgram[i], rtf[i], psd_n[i], reference_channel) for i in range(batch_size)]
+        expected = torch.stack(expected)
+
+        # Batch then transform
+        computed = transform(specgram, rtf, psd_n, reference_channel)
+
+        self.assertEqual(computed, expected)
+
     def test_souden_mvdr(self):
         waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
         specgram = common_utils.get_spectrogram(waveform, n_fft=400)
......
@@ -175,6 +175,15 @@ class Transforms(TestBaseMixin):
         mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
         self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)

+    def test_rtf_mvdr(self):
+        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
+        specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
+        channel, freq, _ = specgram.shape
+        rtf = torch.rand(freq, channel, dtype=self.complex_dtype, device=self.device)
+        psd_n = torch.rand(freq, channel, channel, dtype=self.complex_dtype, device=self.device)
+        reference_channel = 0
+        self._assert_consistency_complex(T.RTFMVDR(), specgram, rtf, psd_n, reference_channel)
+
     def test_souden_mvdr(self):
         tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
         specgram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
......
@@ -1892,35 +1892,39 @@ def mvdr_weights_rtf(
     .. properties:: Autograd TorchScript

+    Given the relative transfer function (RTF) matrix or the steering vector of target speech :math:`\bm{v}`,
+    the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
+    reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
+    :math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
+
     .. math::
         \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}

-    where :math:`\bm{v}` is the RTF or the steering vector.
-    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+    where :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

     Args:
-        rtf (Tensor): The complex-valued RTF vector of target speech.
-            Tensor of dimension `(..., freq, channel)`.
-        psd_n (torch.Tensor): The complex-valued covariance matrix of noise.
-            Tensor of dimension `(..., freq, channel, channel)`
-        reference_channel (int or Tensor, optional): Indicate the reference channel.
-            If the dtype is ``int``, it represent the reference channel index.
-            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
+        rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+            Tensor with dimensions `(..., freq, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_channel (int or torch.Tensor): Specifies the reference channel.
+            If the dtype is ``int``, it represents the reference channel index.
+            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
             is one-hot.
-            If a non-None value is given, the MVDR weights will be normalized by ``rtf[..., reference_channel].conj()``
-            (Default: ``None``)
-        diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
             (Default: ``True``)
-        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
-            (Default: ``1e-7``)
-        eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+        eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+            (Default: ``1e-8``)

     Returns:
-        Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
+        torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
     """
     if diagonal_loading:
-        psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
+        psd_n = _tik_reg(psd_n, reg=diag_eps)
     # numerator = psd_n.inv() @ stv
     numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)  # (..., freq, channel)
     # denominator = stv^H @ psd_n.inv() @ stv
......
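
As a side note, the weight formula documented above can be written out in a few lines of plain PyTorch. The following is only an illustrative sketch of the displayed equation (single batch, no diagonal loading, no reference-channel handling), not a substitute for ``torchaudio.functional.mvdr_weights_rtf``:

```python
import torch

def mvdr_weights_rtf_sketch(rtf: torch.Tensor, psd_n: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Illustrative computation of w = Phi_NN^{-1} v / (v^H Phi_NN^{-1} v).

    rtf:   (freq, channel) complex RTF / steering vector v
    psd_n: (freq, channel, channel) complex noise PSD matrix Phi_NN
    """
    # numerator = Phi_NN^{-1} v, via a linear solve instead of an explicit inverse
    numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)      # (freq, channel)
    # denominator = v^H Phi_NN^{-1} v, one scalar per frequency bin
    denominator = torch.einsum("...c,...c->...", rtf.conj(), numerator)       # (freq,)
    return numerator / (denominator.real.unsqueeze(-1) + eps)                 # (freq, channel)

freq, channel = 201, 4
rtf = torch.randn(freq, channel, dtype=torch.cfloat)
# Build a Hermitian positive-definite noise PSD for the toy example
a = torch.randn(freq, channel, channel, dtype=torch.cfloat)
psd_n = a @ a.conj().transpose(-1, -2) + 1e-3 * torch.eye(channel, dtype=torch.cfloat)
w = mvdr_weights_rtf_sketch(rtf, psd_n)
print(w.shape)  # torch.Size([201, 4])
```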
@@ -24,6 +24,7 @@ from ._transforms import (
     RNNTLoss,
     PSD,
     MVDR,
+    RTFMVDR,
     SoudenMVDR,
 )
@@ -46,6 +47,7 @@ __all__ = [
     "PSD",
     "PitchShift",
     "RNNTLoss",
+    "RTFMVDR",
     "Resample",
     "SlidingWindowCmn",
     "SoudenMVDR",
......
@@ -2091,6 +2091,70 @@ class MVDR(torch.nn.Module):
         return specgram_enhanced


+class RTFMVDR(torch.nn.Module):
+    r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
+    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.
+
+    .. devices:: CPU CUDA
+
+    .. properties:: Autograd TorchScript
+
+    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the relative transfer function (RTF) matrix
+    or the steering vector of target speech :math:`\bm{v}`, the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and
+    a one-hot vector that represents the reference channel :math:`\bf{u}`, the module computes the single-channel
+    complex-valued spectrum of the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:
+
+    .. math::
+        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)
+
+    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the :math:`f`-th frequency bin,
+    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.
+
+    The beamforming weight is computed by:
+
+    .. math::
+        \textbf{w}_{\text{MVDR}}(f) =
+        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
+        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}
+    """
+
+    def forward(
+        self,
+        specgram: Tensor,
+        rtf: Tensor,
+        psd_n: Tensor,
+        reference_channel: Union[int, Tensor],
+        diagonal_loading: bool = True,
+        diag_eps: float = 1e-7,
+        eps: float = 1e-8,
+    ) -> Tensor:
+        """
+        Args:
+            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
+                Tensor with dimensions `(..., channel, freq, time)`
+            rtf (torch.Tensor): The complex-valued RTF vector of target speech.
+                Tensor with dimensions `(..., freq, channel)`.
+            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+                Tensor with dimensions `(..., freq, channel, channel)`.
+            reference_channel (int or torch.Tensor): Specifies the reference channel.
+                If the dtype is ``int``, it represents the reference channel index.
+                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
+                is one-hot.
+            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+                (Default: ``True``)
+            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+                It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
+            eps (float, optional): Value to add to the denominator in the beamforming weight formula.
+                (Default: ``1e-8``)
+
+        Returns:
+            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
+        """
+        w_mvdr = F.mvdr_weights_rtf(rtf, psd_n, reference_channel, diagonal_loading, diag_eps, eps)
+        spectrum_enhanced = F.apply_beamforming(w_mvdr, specgram)
+        return spectrum_enhanced
+
+
 class SoudenMVDR(torch.nn.Module):
     r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
     based on the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
......
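
Putting the pieces together, a hypothetical end-to-end enhancement pipeline with the new transform might look as follows. This is a sketch only: the file name is made up, the masks are random placeholders for the output of a mask-estimation network, and the RTF is taken as the principal eigenvector of the speech PSD matrix (one common choice, not prescribed by this PR):

```python
import torch
import torchaudio
import torchaudio.transforms as T

waveform, sample_rate = torchaudio.load("mixture.wav")   # assumed multi-channel recording
stft = T.Spectrogram(n_fft=400, power=None)               # power=None keeps the complex values
specgram = stft(waveform)                                  # (channel, freq, time)

freq, time = specgram.shape[-2], specgram.shape[-1]
mask_s = torch.rand(freq, time)   # placeholder speech mask
mask_n = torch.rand(freq, time)   # placeholder noise mask

psd_transform = T.PSD()
psd_s = psd_transform(specgram, mask_s)   # (freq, channel, channel)
psd_n = psd_transform(specgram, mask_n)

# One common way to obtain the RTF: principal eigenvector of the speech PSD matrix.
_, eigenvectors = torch.linalg.eigh(psd_s)
rtf = eigenvectors[..., -1]               # (freq, channel), eigenvector of the largest eigenvalue

mvdr = T.RTFMVDR()
enhanced = mvdr(specgram, rtf, psd_n, reference_channel=0)        # (freq, time), complex
waveform_enhanced = T.InverseSpectrogram(n_fft=400)(enhanced)     # back to the time domain
```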