Commit 5d06a369 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add mvdr_weights_souden to torchaudio.functional (#2228)

Summary:
This PR adds ``mvdr_weights_souden`` method to ``torchaudio.functional``.
It computes the MVDR weight matrix based on the solution proposed by [``Souden et al.``](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf).
The input arguments are the complex-valued power spectral density (PSD) matrix of the target speech, PSD matrix of noise, int or one-hot Tensor to indicate the reference channel, respectively.

Pull Request resolved: https://github.com/pytorch/audio/pull/2228

Reviewed By: mthrok

Differential Revision: D34474018

Pulled By: nateanl

fbshipit-source-id: 725df812f8f6e6cc81cc37e8c3cb0da2ab3b74fb
parent 07bd1aa3
...@@ -246,6 +246,11 @@ psd ...@@ -246,6 +246,11 @@ psd
.. autofunction:: psd .. autofunction:: psd
mvdr_weights_souden
-------------------
.. autofunction:: mvdr_weights_souden
:hidden:`Loss` :hidden:`Loss`
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~
......
...@@ -251,3 +251,13 @@ ...@@ -251,3 +251,13 @@
year={2017}, year={2017},
publisher={IEEE} publisher={IEEE}
} }
@article{capon1969high,
title={High-resolution frequency-wavenumber spectrum analysis},
author={Capon, Jack},
journal={Proceedings of the IEEE},
volume={57},
number={8},
pages={1408--1418},
year={1969},
publisher={IEEE}
}
...@@ -12,3 +12,20 @@ def psd_numpy(specgram, mask=None, normalize=True, eps=1e-10): ...@@ -12,3 +12,20 @@ def psd_numpy(specgram, mask=None, normalize=True, eps=1e-10):
psd = psd * mask_normmalized[..., None, None] psd = psd * mask_normmalized[..., None, None]
psd = psd.sum(axis=-3) psd = psd.sum(axis=-3)
return psd return psd
def mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
    """Reference numpy implementation of the Souden et al. MVDR beamforming weights.

    Args:
        psd_s (numpy.ndarray): complex PSD matrix of target speech,
            of shape `(..., freq, channel, channel)`.
        psd_n (numpy.ndarray): complex PSD matrix of noise, same shape as ``psd_s``.
        reference_channel (int or numpy.ndarray): reference channel index, or a
            one-hot vector of shape `(..., channel)`.
        diag_eps (float, optional): coefficient multiplying the trace of ``psd_n``
            for diagonal loading. (Default: 1e-7)
        eps (float, optional): value added to the diagonal loading and to the trace
            denominator to avoid division by zero. (Default: 1e-8)

    Returns:
        numpy.ndarray: beamforming weights of shape `(..., freq, channel)`.
    """
    channel = psd_s.shape[-1]
    eye = np.eye(channel)
    # Use negative axes so batched inputs `(..., freq, channel, channel)` work as well;
    # the previous `axis1=1, axis2=2` silently assumed exactly 3-D input.
    trace = np.trace(psd_n, axis1=-2, axis2=-1)
    # Diagonal loading scaled by the (real) trace keeps psd_n well-conditioned for solve().
    epsilon = trace.real[..., None, None] * diag_eps + eps
    psd_n = psd_n + epsilon * eye
    numerator = np.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    numerator_trace = np.trace(numerator, axis1=-2, axis2=-1)
    ws = numerator / (numerator_trace[..., None, None] + eps)
    if isinstance(reference_channel, int):
        beamform_weights = ws[..., :, reference_channel]
    else:
        # One-hot selection of the reference column via summation over the channel axis.
        beamform_weights = np.einsum("...c,...c->...", ws, reference_channel[..., None, None, :])
    return beamform_weights
...@@ -265,6 +265,24 @@ class Autograd(TestBaseMixin): ...@@ -265,6 +265,24 @@ class Autograd(TestBaseMixin):
mask = None mask = None
self.assert_grad(F.psd, (specgram, mask)) self.assert_grad(F.psd, (specgram, mask))
def test_mvdr_weights_souden(self):
    """Gradcheck ``F.mvdr_weights_souden`` with an integer reference channel."""
    torch.random.manual_seed(2434)
    n_channel, n_freq = 4, 5
    psd_speech = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)
    psd_noise = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)
    self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))
def test_mvdr_weights_souden_with_tensor(self):
    """Gradcheck ``F.mvdr_weights_souden`` with a one-hot reference-channel Tensor."""
    torch.random.manual_seed(2434)
    n_channel, n_freq = 4, 5
    psd_speech = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)
    psd_noise = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)
    reference_channel = torch.zeros(n_channel)
    reference_channel[0] = 1
    self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))
class AutogradFloat32(TestBaseMixin): class AutogradFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
......
...@@ -317,3 +317,27 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -317,3 +317,27 @@ class TestFunctional(common_utils.TorchaudioTestCase):
specgram = specgram.view(batch_size, channel, n_fft_bin, specgram.size(-1)) specgram = specgram.view(batch_size, channel, n_fft_bin, specgram.size(-1))
mask = torch.rand(batch_size, n_fft_bin, specgram.size(-1)) mask = torch.rand(batch_size, n_fft_bin, specgram.size(-1))
self.assert_batch_consistency(F.psd, (specgram, mask)) self.assert_batch_consistency(F.psd, (specgram, mask))
def test_mvdr_weights_souden(self):
    """Check that ``F.mvdr_weights_souden`` is consistent between batched and unbatched calls."""
    torch.random.manual_seed(2434)
    batch_size = 2
    channel = 4
    n_fft_bin = 10
    psd_speech = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
    psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
    func = partial(F.mvdr_weights_souden, reference_channel=0)
    # Fix: arguments were previously passed as (psd_noise, psd_speech), binding the noise
    # PSD to ``psd_s`` and the speech PSD to ``psd_n``. The consistency check still passed
    # (both are random), but the call contradicted the variable names.
    self.assert_batch_consistency(func, (psd_speech, psd_noise))
def test_mvdr_weights_souden_with_tensor(self):
    """Batch consistency of ``F.mvdr_weights_souden`` with a one-hot reference-channel Tensor."""
    torch.random.manual_seed(2434)
    batch_size = 2
    channel = 4
    n_fft_bin = 10
    psd_speech = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
    psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
    reference_channel = torch.zeros(batch_size, channel)
    reference_channel[..., 0].fill_(1)
    # Fix: arguments were previously passed as (psd_noise, psd_speech, ...), binding the
    # noise PSD to ``psd_s`` and the speech PSD to ``psd_n``, contradicting the names.
    self.assert_batch_consistency(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))
...@@ -620,6 +620,57 @@ class Functional(TestBaseMixin): ...@@ -620,6 +620,57 @@ class Functional(TestBaseMixin):
) )
self.assertEqual(torch.tensor(psd, dtype=self.complex_dtype, device=self.device), psd_audio) self.assertEqual(torch.tensor(psd, dtype=self.complex_dtype, device=self.device), psd_audio)
def test_mvdr_weights_souden(self):
    """Verify ``F.mvdr_weights_souden`` against the numpy reference implementation.

    Given the PSD matrices of target speech and noise (Tensors of dimension
    `(..., freq, channel, channel)`) and an integer indicating the reference channel,
    ``F.mvdr_weights_souden`` outputs the MVDR weights (Tensor of dimension
    `(..., freq, channel)`), which should be close to the output of
    ``mvdr_weights_souden_numpy``.
    """
    n_fft_bin, channel = 10, 4
    reference_channel = 0
    shape = (n_fft_bin, channel, channel)
    psd_s = np.random.random(shape) + np.random.random(shape) * 1j
    psd_n = np.random.random(shape) + np.random.random(shape) * 1j
    expected = beamform_utils.mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel)
    actual = F.mvdr_weights_souden(
        torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
        torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
        reference_channel,
    )
    self.assertEqual(
        torch.tensor(expected, dtype=self.complex_dtype, device=self.device),
        actual,
        atol=1e-3,
        rtol=1e-6,
    )
def test_mvdr_weights_souden_with_tensor(self):
    """Verify ``F.mvdr_weights_souden`` against the numpy reference implementation.

    Given the PSD matrices of target speech and noise (Tensors of dimension
    `(..., freq, channel, channel)`) and a one-hot Tensor indicating the reference
    channel, ``F.mvdr_weights_souden`` outputs the MVDR weights (Tensor of dimension
    `(..., freq, channel)`), which should be close to the output of
    ``mvdr_weights_souden_numpy``.
    """
    n_fft_bin, channel = 10, 4
    reference_channel = np.zeros(channel)
    reference_channel[0] = 1
    shape = (n_fft_bin, channel, channel)
    psd_s = np.random.random(shape) + np.random.random(shape) * 1j
    psd_n = np.random.random(shape) + np.random.random(shape) * 1j
    expected = beamform_utils.mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel)
    actual = F.mvdr_weights_souden(
        torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
        torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
        torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
    )
    self.assertEqual(
        torch.tensor(expected, dtype=self.complex_dtype, device=self.device),
        actual,
        atol=1e-3,
        rtol=1e-6,
    )
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self): def test_melscale_fbanks_no_warning_high_n_freq(self):
......
...@@ -638,6 +638,32 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -638,6 +638,32 @@ class Functional(TempDirMixin, TestBaseMixin):
mask = torch.rand(batch_size, n_fft_bin, frame, device=self.device) mask = torch.rand(batch_size, n_fft_bin, frame, device=self.device)
self._assert_consistency_complex(F.psd, (specgram, mask, normalize, eps)) self._assert_consistency_complex(F.psd, (specgram, mask, normalize, eps))
def test_mvdr_weights_souden(self):
    """TorchScript consistency of ``F.mvdr_weights_souden`` with an integer reference channel."""
    channel, n_fft_bin = 4, 10
    psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
    psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
    # reference_channel=0, diagonal_loading=True, diag_eps=1e-7, eps=1e-8
    args = (psd_speech, psd_noise, 0, True, 1e-7, 1e-8)
    self._assert_consistency_complex(F.mvdr_weights_souden, args)
def test_mvdr_weights_souden_with_tensor(self):
    """TorchScript consistency of ``F.mvdr_weights_souden`` with a one-hot reference-channel Tensor."""
    channel, n_fft_bin = 4, 10
    psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
    psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
    reference_channel = torch.zeros(channel)
    reference_channel[..., 0] = 1
    # diagonal_loading=True, diag_eps=1e-7, eps=1e-8
    args = (psd_speech, psd_noise, reference_channel, True, 1e-7, 1e-8)
    self._assert_consistency_complex(F.mvdr_weights_souden, args)
class FunctionalFloat32Only(TestBaseMixin): class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
......
...@@ -47,6 +47,7 @@ from .functional import ( ...@@ -47,6 +47,7 @@ from .functional import (
pitch_shift, pitch_shift,
rnnt_loss, rnnt_loss,
psd, psd,
mvdr_weights_souden,
) )
__all__ = [ __all__ = [
...@@ -96,4 +97,5 @@ __all__ = [ ...@@ -96,4 +97,5 @@ __all__ = [
"pitch_shift", "pitch_shift",
"rnnt_loss", "rnnt_loss",
"psd", "psd",
"mvdr_weights_souden",
] ]
...@@ -4,7 +4,7 @@ import io ...@@ -4,7 +4,7 @@ import io
import math import math
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torchaudio import torchaudio
...@@ -38,6 +38,7 @@ __all__ = [ ...@@ -38,6 +38,7 @@ __all__ = [
"pitch_shift", "pitch_shift",
"rnnt_loss", "rnnt_loss",
"psd", "psd",
"mvdr_weights_souden",
] ]
...@@ -1669,3 +1670,97 @@ def psd( ...@@ -1669,3 +1670,97 @@ def psd(
psd = psd.sum(dim=-3) psd = psd.sum(dim=-3)
return psd return psd
def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
Returns:
Tensor: trace of the input Tensor
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Perform Tikhonov regularization (only modifying real part).

    Args:
        mat (torch.Tensor): input matrix (..., channel, channel)
        reg (float, optional): regularization factor, scaled by the real part of the
            trace of ``mat`` (Default: ``1e-7``)
        eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)

    Returns:
        Tensor: regularized matrix (..., channel, channel)
    """
    # Fix: the docstring previously stated the default for ``reg`` is 1e-8,
    # contradicting the actual signature default of 1e-7.
    C = mat.size(-1)
    eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
    # Scale the loading by the trace so it adapts to the matrix's magnitude;
    # eps keeps the loading non-zero even when the correlation matrix is all-zero.
    epsilon = _compute_mat_trace(mat).real[..., None, None] * reg + eps
    return mat + epsilon * eye
def mvdr_weights_souden(
    psd_s: Tensor,
    psd_n: Tensor,
    reference_channel: Union[int, Tensor],
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
    by the method proposed by *Souden et al.* [:footcite:`souden2009optimal`].

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
        {\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}

    where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
    are the power spectral density (PSD) matrices of speech and noise, respectively.
    :math:`\bf{u}` is a one-hot vector that represents the reference channel.

    Args:
        psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`
        psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
            Tensor of dimension `(..., freq, channel, channel)`
        reference_channel (int or Tensor): Indicate the reference channel.
            If the dtype is ``int``, it represent the reference channel index.
            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
            is one-hot.
        diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
            (Default: ``1e-7``)
        eps (float, optional): a value added to the trace of the numerator in the denominator,
            and used as the additive floor of the diagonal loading, to avoid division by zero.
            (Default: ``1e-8``)

    Returns:
        Tensor: The complex-valued MVDR beamforming weight matrix of dimension (..., freq, channel).
    """
    if diagonal_loading:
        # Regularize psd_n before inversion so the linear solve below is well-conditioned.
        psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
    # Solve instead of explicitly inverting: numerically more stable.
    numerator = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    # ws: (..., C, C) / (...,) -> (..., C, C)
    ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
    # torch.jit.isinstance (rather than builtin isinstance) — presumably for
    # TorchScript compatibility of the Union-typed argument; confirm if refactoring.
    if torch.jit.isinstance(reference_channel, int):
        # Integer index: take the corresponding column of ws.
        beamform_weights = ws[..., :, reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
        # One-hot selection of the reference column via summation over the channel axis.
        beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    return beamform_weights
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment