Commit 3566ffc5 authored by Zhaoheng Ni and committed by Facebook GitHub Bot

Add mvdr_weights_rtf to torchaudio.functional (#2229)

Summary:
This PR adds ``mvdr_weights_rtf`` method to ``torchaudio.functional``.
It computes the MVDR weight matrix based on the solution that applies the relative transfer function (RTF). See [the paper](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf) for reference.
The input arguments are the complex-valued RTF Tensor of the target speech, the power spectral density (PSD) matrix of noise, and an int or one-hot Tensor indicating the reference channel.
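A minimal usage sketch (the shapes follow the docstring added in this PR; the random tensors below are placeholders rather than statistics estimated from real audio):

```python
import torch
import torchaudio.functional as F

freq, channel = 257, 4
rtf = torch.rand(freq, channel, dtype=torch.cfloat)             # RTF of the target speech
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)  # PSD matrix of the noise
weights = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0)
print(weights.shape)  # torch.Size([257, 4])
```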

Pull Request resolved: https://github.com/pytorch/audio/pull/2229

Reviewed By: mthrok

Differential Revision: D34474119

Pulled By: nateanl

fbshipit-source-id: 2d6f62cd0858f29ed6e4e03c23dcc11c816204e2
parent 5d06a369
@@ -251,6 +251,11 @@ mvdr_weights_souden

.. autofunction:: mvdr_weights_souden

mvdr_weights_rtf
----------------

.. autofunction:: mvdr_weights_rtf

:hidden:`Loss`
~~~~~~~~~~~~~~
......
@@ -29,3 +29,21 @@ def mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
    else:
        beamform_weights = np.einsum("...c,...c->...", ws, reference_channel[..., None, None, :])
    return beamform_weights


def mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
    channel = rtf.shape[-1]
    eye = np.eye(channel)
    # Diagonal loading: regularize psd_n with a scaled identity matrix for numerical stability.
    trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
    epsilon = trace.real[..., None, None] * diag_eps + eps
    diag = epsilon * eye[..., :, :]
    psd_n = psd_n + diag
    # numerator = psd_n^-1 @ rtf, computed per frequency bin via a linear solve.
    numerator = np.linalg.solve(psd_n, np.expand_dims(rtf, -1)).squeeze(-1)
    # denominator = rtf^H @ psd_n^-1 @ rtf
    denominator = np.einsum("...d,...d->...", rtf.conj(), numerator)
    beamform_weights = numerator / (np.expand_dims(denominator.real, -1) + eps)
    # Rescale the weights by the conjugate of the RTF value at the reference channel.
    if isinstance(reference_channel, int):
        scale = rtf[..., reference_channel].conj()
    else:
        scale = np.einsum("...c,...c->...", rtf.conj(), reference_channel[..., None, :])
    beamform_weights = beamform_weights * scale[..., None]
    return beamform_weights
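# Illustrative sanity check (not part of the PR): with a Hermitian positive-definite
# noise PSD, the weights returned by ``mvdr_weights_rtf_numpy`` satisfy the MVDR
# distortionless constraint at the reference channel, w(f)^H rtf(f) ~= rtf(f, ref),
# up to the small diagonal-loading and eps terms. The helper below is hypothetical
# and only for illustration; ``np`` refers to numpy as imported at the top of this file.
def _check_distortionless_constraint(n_fft_bin=10, channel=4, reference_channel=0):
    rng = np.random.default_rng(0)
    rtf = rng.standard_normal((n_fft_bin, channel)) + 1j * rng.standard_normal((n_fft_bin, channel))
    # Build a Hermitian positive-definite noise PSD as A @ A^H plus a small identity.
    a = rng.standard_normal((n_fft_bin, channel, channel)) + 1j * rng.standard_normal((n_fft_bin, channel, channel))
    psd_n = a @ a.conj().transpose(0, 2, 1) + 1e-3 * np.eye(channel)
    weights = mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
    lhs = np.einsum("...c,...c->...", weights.conj(), rtf)  # w(f)^H rtf(f) per frequency bin
    rhs = rtf[..., reference_channel]
    assert np.allclose(lhs, rhs, atol=1e-6)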
@@ -283,6 +283,26 @@ class Autograd(TestBaseMixin):
        reference_channel[0].fill_(1)
        self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))

    def test_mvdr_weights_rtf(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, 0))

    def test_mvdr_weights_rtf_with_tensor(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 10
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))


class AutogradFloat32(TestBaseMixin):
    def assert_grad(
......
@@ -341,3 +341,27 @@ class TestFunctional(common_utils.TorchaudioTestCase):
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_batch_consistency(F.mvdr_weights_souden, (psd_noise, psd_speech, reference_channel))

    def test_mvdr_weights_rtf(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 129
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
        kwargs = {
            "reference_channel": 0,
        }
        func = partial(F.mvdr_weights_rtf, **kwargs)
        self.assert_batch_consistency(func, (rtf, psd_noise))

    def test_mvdr_weights_rtf_with_tensor(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 129
        rtf = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
        psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_batch_consistency(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))
@@ -671,6 +671,59 @@ class Functional(TestBaseMixin):
            rtol=1e-6,
        )

    def test_mvdr_weights_rtf(self):
        """Verify ``F.mvdr_weights_rtf`` method against the numpy implementation.
        Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
        the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and an integer
        indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the MVDR weights
        (Tensor of dimension `(..., freq, channel)`), which should be close to the output of
        ``mvdr_weights_rtf_numpy``.
        """
        n_fft_bin = 10
        channel = 4
        reference_channel = 0
        rtf = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
        psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
        beamform_weights = beamform_utils.mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
        beamform_weights_audio = F.mvdr_weights_rtf(
            torch.tensor(rtf, dtype=self.complex_dtype, device=self.device),
            torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
            reference_channel,
        )
        self.assertEqual(
            torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
            beamform_weights_audio,
            atol=1e-3,
            rtol=1e-6,
        )

    def test_mvdr_weights_rtf_with_tensor(self):
        """Verify ``F.mvdr_weights_rtf`` method against the numpy implementation.
        Given the relative transfer function (RTF) of target speech (Tensor of dimension `(..., freq, channel)`),
        the PSD matrix of noise (Tensor of dimension `(..., freq, channel, channel)`), and a one-hot Tensor
        indicating the reference channel as inputs, ``F.mvdr_weights_rtf`` outputs the MVDR weights
        (Tensor of dimension `(..., freq, channel)`), which should be close to the output of
        ``mvdr_weights_rtf_numpy``.
        """
        n_fft_bin = 10
        channel = 4
        reference_channel = np.zeros(channel)
        reference_channel[0] = 1
        rtf = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
        psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
        beamform_weights = beamform_utils.mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel)
        beamform_weights_audio = F.mvdr_weights_rtf(
            torch.tensor(rtf, dtype=self.complex_dtype, device=self.device),
            torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
            torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
        )
        self.assertEqual(
            torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
            beamform_weights_audio,
            atol=1e-3,
            rtol=1e-6,
        )


class FunctionalCPUOnly(TestBaseMixin):
    def test_melscale_fbanks_no_warning_high_n_freq(self):
......
@@ -664,6 +664,33 @@ class Functional(TempDirMixin, TestBaseMixin):
            F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel, diagonal_loading, diag_eps, eps)
        )

    def test_mvdr_weights_rtf(self):
        channel = 4
        n_fft_bin = 10
        diagonal_loading = True
        diag_eps = 1e-7
        eps = 1e-8
        rtf = torch.rand(n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = 0
        self._assert_consistency_complex(
            F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel, diagonal_loading, diag_eps, eps)
        )

    def test_mvdr_weights_rtf_with_tensor(self):
        channel = 4
        n_fft_bin = 10
        diagonal_loading = True
        diag_eps = 1e-7
        eps = 1e-8
        rtf = torch.rand(n_fft_bin, channel, dtype=self.complex_dtype)
        psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
        reference_channel = torch.zeros(channel)
        reference_channel[..., 0].fill_(1)
        self._assert_consistency_complex(
            F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel, diagonal_loading, diag_eps, eps)
        )


class FunctionalFloat32Only(TestBaseMixin):
    def test_rnnt_loss(self):
......
@@ -48,6 +48,7 @@ from .functional import (
    rnnt_loss,
    psd,
    mvdr_weights_souden,
    mvdr_weights_rtf,
)

__all__ = [
@@ -98,4 +99,5 @@ __all__ = [
    "rnnt_loss",
    "psd",
    "mvdr_weights_souden",
    "mvdr_weights_rtf",
]
@@ -39,6 +39,7 @@ __all__ = [
    "rnnt_loss",
    "psd",
    "mvdr_weights_souden",
    "mvdr_weights_rtf",
]
@@ -1764,3 +1765,63 @@ def mvdr_weights_souden(
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    return beamform_weights


def mvdr_weights_rtf(
    rtf: Tensor,
    psd_n: Tensor,
    reference_channel: Optional[Union[int, Tensor]] = None,
    diagonal_loading: bool = True,
    diag_eps: float = 1e-7,
    eps: float = 1e-8,
) -> Tensor:
    r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
    based on the relative transfer function (RTF) and power spectral density (PSD) matrix of noise.

    .. math::
        \textbf{w}_{\text{MVDR}}(f) =
        \frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}}
        {{\bm{v}^{\mathsf{H}}}(f){\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bm{v}}(f)}

    where :math:`\bm{v}` is the RTF or the steering vector.
    :math:`(.)^{\mathsf{H}}` denotes the Hermitian Conjugate operation.

    Args:
        rtf (Tensor): The complex-valued RTF vector of target speech.
            Tensor of dimension `(..., freq, channel)`.
        psd_n (Tensor): The complex-valued covariance matrix of noise.
            Tensor of dimension `(..., freq, channel, channel)`.
        reference_channel (int or Tensor, optional): Specifies the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
            is one-hot.
            If a non-None value is given, the beamforming weights are rescaled by ``rtf[..., reference_channel].conj()``.
            (Default: ``None``)
        diagonal_loading (bool, optional): Whether to apply diagonal loading to ``psd_n``.
            (Default: ``True``)
        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
            (Default: ``1e-7``)
        eps (float, optional): Value added to the denominator when computing the beamforming weights,
            and used as the additive term in diagonal loading. (Default: ``1e-8``)

    Returns:
        Tensor: The complex-valued MVDR beamforming weight matrix of dimension `(..., freq, channel)`.
    """
    if diagonal_loading:
        psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
    # numerator = psd_n.inv() @ stv
    numerator = torch.linalg.solve(psd_n, rtf.unsqueeze(-1)).squeeze(-1)  # (..., freq, channel)
    # denominator = stv^H @ psd_n.inv() @ stv
    denominator = torch.einsum("...d,...d->...", [rtf.conj(), numerator])
    beamform_weights = numerator / (denominator.real.unsqueeze(-1) + eps)
    # rescale the weights by the RTF value at the reference channel
    if reference_channel is not None:
        if torch.jit.isinstance(reference_channel, int):
            scale = rtf[..., reference_channel].conj()
        elif torch.jit.isinstance(reference_channel, Tensor):
            reference_channel = reference_channel.to(psd_n.dtype)
            scale = torch.einsum("...c,...c->...", [rtf.conj(), reference_channel[..., None, :]])
        else:
            raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
        beamform_weights = beamform_weights * scale[..., None]
    return beamform_weights
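# Illustration (not part of the PR): a standalone, hedged sketch of how the returned
# weights can be applied to a multi-channel complex spectrogram to produce a
# single-channel enhanced spectrogram. The variable names, shapes, and the einsum
# convention below are assumptions for this example only, not torchaudio API.
import torch
import torchaudio.functional as F

channel, freq, n_frame = 4, 257, 100
specgram = torch.rand(channel, freq, n_frame, dtype=torch.cfloat)  # X(c, f, t)
rtf = torch.rand(freq, channel, dtype=torch.cfloat)                # v(f, c)
psd_n = torch.rand(freq, channel, channel, dtype=torch.cfloat)     # Phi_NN(f, c, c')
weights = F.mvdr_weights_rtf(rtf, psd_n, reference_channel=0)      # w(f, c)
# Y(f, t) = w(f)^H X(f, t): conjugate the weights and sum over the channel dimension.
enhanced = torch.einsum("...fc,...cft->...ft", weights.conj(), specgram)
print(enhanced.shape)  # torch.Size([257, 100])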