Commit aed5eb88 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add SoudenMVDR module (#2367)

Summary:
Add a new design of MVDR module.
The `SoudenMVDR` module supports the method proposed by [Souden et al.](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf).
The input arguments are:
- multi-channel spectrum.
- PSD matrix of target speech.
- PSD matrix of noise.
- reference channel in the microphone array.
- diagonal_loading option to enable or disable diagonal loading in matrix inverse computation.
- diag_eps for computing the inverse of the matrix.
- eps for computing the beamforming weight.

The output of the module is the single-channel complex-valued spectrum for the enhanced speech.

Pull Request resolved: https://github.com/pytorch/audio/pull/2367

Reviewed By: hwangjeff

Differential Revision: D36198015

Pulled By: nateanl

fbshipit-source-id: 4027f4752a84aaef730ef3ea8c625e801cc35527
parent 54d2d04f
...@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas ...@@ -197,6 +197,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward .. automethod:: forward
:hidden:`SoudenMVDR`
--------------------
.. autoclass:: SoudenMVDR
.. automethod:: forward
References References
~~~~~~~~~~ ~~~~~~~~~~
......
...@@ -303,6 +303,16 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -303,6 +303,16 @@ class AutogradTestMixin(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:]) mask_n = torch.rand(spectrogram.shape[-2:])
self.assert_grad(transform, [spectrogram, mask_s, mask_n]) self.assert_grad(transform, [spectrogram, mask_s, mask_n])
def test_souden_mvdr(self):
    """Gradcheck for the ``SoudenMVDR`` transform with random complex PSD matrices."""
    noise = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
    specgram = get_spectrogram(noise, n_fft=400)
    channel, freq, _ = specgram.shape
    psd_shape = (freq, channel, channel)
    psd_s = torch.rand(psd_shape, dtype=torch.cfloat)
    psd_n = torch.rand(psd_shape, dtype=torch.cfloat)
    # Reference channel 0 is passed as a plain int.
    self.assert_grad(T.SoudenMVDR(), [specgram, psd_s, psd_n, 0])
class AutogradTestFloat32(TestBaseMixin): class AutogradTestFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
......
...@@ -219,3 +219,22 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -219,3 +219,22 @@ class TestTransforms(common_utils.TorchaudioTestCase):
computed = transform(specgram, mask_s, mask_n) computed = transform(specgram, mask_s, mask_n)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_souden_mvdr(self):
    """Applying ``SoudenMVDR`` to a batch must match per-example application."""
    batch_size, channel = 3, 2
    waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
    specgram = common_utils.get_spectrogram(waveform, n_fft=400)
    freq, time = specgram.shape[-2], specgram.shape[-1]
    # Fold the 6 channels into a (batch, channel) layout.
    specgram = specgram.reshape(batch_size, channel, freq, time)
    psd_s = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
    psd_n = torch.rand(batch_size, freq, channel, channel, dtype=torch.cfloat)
    ref_channel = 0
    transform = T.SoudenMVDR()
    # Reference result: run each example separately, then stack.
    expected = torch.stack(
        [transform(specgram[b], psd_s[b], psd_n[b], ref_channel) for b in range(batch_size)]
    )
    # The batched call must agree with the stacked per-example results.
    computed = transform(specgram, psd_s, psd_n, ref_channel)
    self.assertEqual(computed, expected)
...@@ -175,6 +175,15 @@ class Transforms(TestBaseMixin): ...@@ -175,6 +175,15 @@ class Transforms(TestBaseMixin):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n) self._assert_consistency_complex(T.MVDR(solution=solution, online=online), spectrogram, mask_s, mask_n)
def test_souden_mvdr(self):
    """TorchScript consistency check for the ``SoudenMVDR`` transform."""
    waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
    specgram = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=100)
    channel, freq, _ = specgram.shape
    tensor_kwargs = {"dtype": self.complex_dtype, "device": self.device}
    psd_s = torch.rand(freq, channel, channel, **tensor_kwargs)
    psd_n = torch.rand(freq, channel, channel, **tensor_kwargs)
    # Reference channel 0 is passed as a plain int.
    self._assert_consistency_complex(T.SoudenMVDR(), specgram, psd_s, psd_n, 0)
class TransformsFloat32Only(TestBaseMixin): class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
......
...@@ -1831,34 +1831,37 @@ def mvdr_weights_souden( ...@@ -1831,34 +1831,37 @@ def mvdr_weights_souden(
.. properties:: Autograd TorchScript .. properties:: Autograd TorchScript
Given the power spectral density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`,
the PSD matrix of noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector that represents the
reference channel :math:`\bf{u}`, the method computes the MVDR beamforming weight matrix
:math:`\textbf{w}_{\text{MVDR}}`. The formula is defined as:
.. math:: .. math::
\textbf{w}_{\text{MVDR}}(f) = \textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)} \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u} {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
are the power spectral density (PSD) matrices of speech and noise, respectively.
:math:`\bf{u}` is a one-hot vector that represents the reference channel.
Args: Args:
psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech. psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)` Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise. psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor of dimension `(..., freq, channel, channel)` Tensor with dimensions `(..., freq, channel, channel)`.
reference_channel (int or Tensor): Indicate the reference channel. reference_channel (int or torch.Tensor): Specifies the reference channel.
If the dtype is ``int``, it represent the reference channel index. If the dtype is ``int``, it represents the reference channel index.
If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
is one-hot. is one-hot.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``) (Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
(Default: ``1e-7``) It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-8``) eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns: Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel). torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
""" """
if diagonal_loading: if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps) psd_n = _tik_reg(psd_n, reg=diag_eps)
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C) # ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps) ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
......
...@@ -24,6 +24,7 @@ from ._transforms import ( ...@@ -24,6 +24,7 @@ from ._transforms import (
RNNTLoss, RNNTLoss,
PSD, PSD,
MVDR, MVDR,
SoudenMVDR,
) )
...@@ -47,6 +48,7 @@ __all__ = [ ...@@ -47,6 +48,7 @@ __all__ = [
"RNNTLoss", "RNNTLoss",
"Resample", "Resample",
"SlidingWindowCmn", "SlidingWindowCmn",
"SoudenMVDR",
"SpectralCentroid", "SpectralCentroid",
"Spectrogram", "Spectrogram",
"TimeMasking", "TimeMasking",
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
import math import math
import warnings import warnings
from typing import Callable, Optional from typing import Callable, Optional, Union
import torch import torch
from torch import Tensor from torch import Tensor
...@@ -2089,3 +2089,66 @@ class MVDR(torch.nn.Module): ...@@ -2089,3 +2089,66 @@ class MVDR(torch.nn.Module):
specgram_enhanced.to(dtype) specgram_enhanced.to(dtype)
return specgram_enhanced return specgram_enhanced
class SoudenMVDR(torch.nn.Module):
    r"""Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) module
    based on the method proposed by *Souden et al.* [:footcite:`souden2009optimal`].

    .. devices:: CPU CUDA

    .. properties:: Autograd TorchScript

    Given the multi-channel complex-valued spectrum :math:`\textbf{Y}`, the power spectral
    density (PSD) matrix of target speech :math:`\bf{\Phi}_{\textbf{SS}}`, the PSD matrix of
    noise :math:`\bf{\Phi}_{\textbf{NN}}`, and a one-hot vector :math:`\bf{u}` that selects
    the reference channel, the module produces the single-channel complex-valued spectrum of
    the enhanced speech :math:`\hat{\textbf{S}}`. The formula is defined as:

    .. math::
        \hat{\textbf{S}}(f) = \textbf{w}_{\text{bf}}(f)^{\mathsf{H}} \textbf{Y}(f)

    where :math:`\textbf{w}_{\text{bf}}(f)` is the MVDR beamforming weight for the
    :math:`f`-th frequency bin, computed by:

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
    """

    def forward(
        self,
        specgram: Tensor,
        psd_s: Tensor,
        psd_n: Tensor,
        reference_channel: Union[int, Tensor],
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        """
        Args:
            specgram (torch.Tensor): Multi-channel complex-valued spectrum.
                Tensor with dimensions `(..., channel, freq, time)`.
            psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
                Tensor with dimensions `(..., freq, channel, channel)`.
            psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
                Tensor with dimensions `(..., freq, channel, channel)`.
            reference_channel (int or torch.Tensor): Specifies the reference channel.
                If the dtype is ``int``, it represents the reference channel index.
                If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the
                ``channel`` dimension is one-hot.
            diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading
                to ``psd_n``. (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for
                diagonal loading. It is only effective when ``diagonal_loading`` is set to
                ``True``. (Default: ``1e-7``)
            eps (float, optional): Value to add to the denominator in the beamforming weight
                formula. (Default: ``1e-8``)

        Returns:
            torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
        """
        # Compute the per-frequency MVDR weights, then apply them to the multi-channel spectrum.
        beamform_weights = F.mvdr_weights_souden(
            psd_s, psd_n, reference_channel, diagonal_loading, diag_eps, eps
        )
        return F.apply_beamforming(beamform_weights, specgram)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment