Commit 5d06a369 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add mvdr_weights_souden to torchaudio.functional (#2228)

Summary:
This PR adds the ``mvdr_weights_souden`` method to ``torchaudio.functional``.
It computes the MVDR beamforming weight matrix based on the solution proposed by [``Souden et al.``](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.725.673&rep=rep1&type=pdf).
The input arguments are the complex-valued power spectral density (PSD) matrix of the target speech, the PSD matrix of the noise, and an int or one-hot Tensor indicating the reference channel.
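
A minimal usage sketch (illustrative only; the shapes are assumptions and the random complex matrices stand in for PSDs that would normally come from ``torchaudio.functional.psd``):

import torch
import torchaudio.functional as F

n_freq, n_channel = 257, 6
psd_speech = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)  # (..., freq, channel, channel)
psd_noise = torch.rand(n_freq, n_channel, n_channel, dtype=torch.cfloat)
weights = F.mvdr_weights_souden(psd_speech, psd_noise, reference_channel=0)
print(weights.shape)  # torch.Size([257, 6]), i.e. (..., freq, channel)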

Pull Request resolved: https://github.com/pytorch/audio/pull/2228

Reviewed By: mthrok

Differential Revision: D34474018

Pulled By: nateanl

fbshipit-source-id: 725df812f8f6e6cc81cc37e8c3cb0da2ab3b74fb
parent 07bd1aa3
@@ -246,6 +246,11 @@ psd
.. autofunction:: psd
mvdr_weights_souden
-------------------
.. autofunction:: mvdr_weights_souden
:hidden:`Loss`
~~~~~~~~~~~~~~
......
@@ -251,3 +251,13 @@
year={2017},
publisher={IEEE}
}
@article{capon1969high,
title={High-resolution frequency-wavenumber spectrum analysis},
author={Capon, Jack},
journal={Proceedings of the IEEE},
volume={57},
number={8},
pages={1408--1418},
year={1969},
publisher={IEEE}
}
@@ -12,3 +12,20 @@ def psd_numpy(specgram, mask=None, normalize=True, eps=1e-10):
psd = psd * mask_normmalized[..., None, None]
psd = psd.sum(axis=-3)
return psd
def mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel, diag_eps=1e-7, eps=1e-8):
channel = psd_s.shape[-1]
eye = np.eye(channel)
trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
epsilon = trace.real[..., None, None] * diag_eps + eps
diag = epsilon * eye[..., :, :]
psd_n = psd_n + diag
numerator = np.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
numerator_trace = np.matrix.trace(numerator, axis1=1, axis2=2)
ws = numerator / (numerator_trace[..., None, None] + eps)
if isinstance(reference_channel, int):
beamform_weights = ws[..., :, reference_channel]
else:
beamform_weights = np.einsum("...c,...c->...", ws, reference_channel[..., None, None, :])
return beamform_weights
@@ -265,6 +265,24 @@ class Autograd(TestBaseMixin):
mask = None
self.assert_grad(F.psd, (specgram, mask))
def test_mvdr_weights_souden(self):
torch.random.manual_seed(2434)
channel = 4
n_fft_bin = 5
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, 0))
def test_mvdr_weights_souden_with_tensor(self):
torch.random.manual_seed(2434)
channel = 4
n_fft_bin = 5
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
reference_channel = torch.zeros(channel)
reference_channel[0].fill_(1)
self.assert_grad(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))
class AutogradFloat32(TestBaseMixin):
def assert_grad(
......
@@ -317,3 +317,27 @@ class TestFunctional(common_utils.TorchaudioTestCase):
specgram = specgram.view(batch_size, channel, n_fft_bin, specgram.size(-1))
mask = torch.rand(batch_size, n_fft_bin, specgram.size(-1))
self.assert_batch_consistency(F.psd, (specgram, mask))
def test_mvdr_weights_souden(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
n_fft_bin = 10
psd_speech = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
kwargs = {
"reference_channel": 0,
}
func = partial(F.mvdr_weights_souden, **kwargs)
self.assert_batch_consistency(func, (psd_speech, psd_noise))
def test_mvdr_weights_souden_with_tensor(self):
torch.random.manual_seed(2434)
batch_size = 2
channel = 4
n_fft_bin = 10
psd_speech = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=torch.cfloat)
reference_channel = torch.zeros(batch_size, channel)
reference_channel[..., 0].fill_(1)
self.assert_batch_consistency(F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel))
@@ -620,6 +620,57 @@ class Functional(TestBaseMixin):
)
self.assertEqual(torch.tensor(psd, dtype=self.complex_dtype, device=self.device), psd_audio)
def test_mvdr_weights_souden(self):
"""Verify ``F.mvdr_weights_souden`` method by numpy implementation.
Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
and an integer indicating the reference channel, ``F.mvdr_weights_souden`` outputs the mvdr weights
(Tensor of dimension `(..., freq, channel)`), which should be close to the output of
``mvdr_weights_souden_numpy``.
"""
n_fft_bin = 10
channel = 4
reference_channel = 0
psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
beamform_weights = beamform_utils.mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel)
beamform_weights_audio = F.mvdr_weights_souden(
torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
reference_channel,
)
self.assertEqual(
torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
beamform_weights_audio,
atol=1e-3,
rtol=1e-6,
)
def test_mvdr_weights_souden_with_tensor(self):
"""Verify ``F.mvdr_weights_souden`` method by numpy implementation.
Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
and a one-hot Tensor indicating the reference channel, ``F.mvdr_weights_souden`` outputs the mvdr weights
(Tensor of dimension `(..., freq, channel)`), which should be close to the output of
``mvdr_weights_souden_numpy``.
"""
n_fft_bin = 10
channel = 4
reference_channel = np.zeros(channel)
reference_channel[0] = 1
psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
beamform_weights = beamform_utils.mvdr_weights_souden_numpy(psd_s, psd_n, reference_channel)
beamform_weights_audio = F.mvdr_weights_souden(
torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
)
self.assertEqual(
torch.tensor(beamform_weights, dtype=self.complex_dtype, device=self.device),
beamform_weights_audio,
atol=1e-3,
rtol=1e-6,
)
class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
......
@@ -638,6 +638,32 @@ class Functional(TempDirMixin, TestBaseMixin):
mask = torch.rand(batch_size, n_fft_bin, frame, device=self.device)
self._assert_consistency_complex(F.psd, (specgram, mask, normalize, eps))
def test_mvdr_weights_souden(self):
channel = 4
n_fft_bin = 10
diagonal_loading = True
diag_eps = 1e-7
eps = 1e-8
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
self._assert_consistency_complex(
F.mvdr_weights_souden, (psd_speech, psd_noise, 0, diagonal_loading, diag_eps, eps)
)
def test_mvdr_weights_souden_with_tensor(self):
channel = 4
n_fft_bin = 10
diagonal_loading = True
diag_eps = 1e-7
eps = 1e-8
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
reference_channel = torch.zeros(channel)
reference_channel[..., 0].fill_(1)
self._assert_consistency_complex(
F.mvdr_weights_souden, (psd_speech, psd_noise, reference_channel, diagonal_loading, diag_eps, eps)
)
class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
......
@@ -47,6 +47,7 @@ from .functional import (
pitch_shift,
rnnt_loss,
psd,
mvdr_weights_souden,
)
__all__ = [
@@ -96,4 +97,5 @@ __all__ = [
"pitch_shift",
"rnnt_loss",
"psd",
"mvdr_weights_souden",
]
@@ -4,7 +4,7 @@ import io
import math
import warnings
from collections.abc import Sequence
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torchaudio
@@ -38,6 +38,7 @@ __all__ = [
"pitch_shift",
"rnnt_loss",
"psd",
"mvdr_weights_souden",
]
@@ -1669,3 +1670,97 @@ def psd(
psd = psd.sum(dim=-3)
return psd
def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension along which to take the diagonal
(Default: ``-1``)
dim2 (int, optional): the second dimension along which to take the diagonal
(Default: ``-2``)
Returns:
Tensor: trace of the input Tensor
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: ``1e-7``)
eps (float, optional): a value added in case the correlation matrix is all-zero (Default: ``1e-8``)
Returns:
Tensor: regularized matrix (..., channel, channel)
"""
# Add eps
C = mat.size(-1)
eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
epsilon = _compute_mat_trace(mat).real[..., None, None] * reg
# in case that correlation_matrix is all-zero
epsilon = epsilon + eps
mat = mat + epsilon * eye[..., :, :]
return mat
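# Illustrative example with assumed numbers: for a PSD matrix whose trace is 2.0,
# _tik_reg with reg=1e-7 and eps=1e-8 adds 2.0 * 1e-7 + 1e-8 = 2.1e-7 to each
# diagonal entry, which keeps psd_n well-conditioned for torch.linalg.solve below.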
def mvdr_weights_souden(
psd_s: Tensor,
psd_n: Tensor,
reference_channel: Union[int, Tensor],
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> Tensor:
r"""Compute the Minimum Variance Distortionless Response (*MVDR* [:footcite:`capon1969high`]) beamforming weights
by the method proposed by *Souden et, al.* [:footcite:`souden2009optimal`].
.. math::
\textbf{w}_{\text{MVDR}}(f) =
\frac{{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f){\bf{\Phi}_{\textbf{SS}}}}(f)}
{\text{Trace}({{{\bf{\Phi}_{\textbf{NN}}^{-1}}(f) \bf{\Phi}_{\textbf{SS}}}(f))}}\bm{u}
where :math:`\bf{\Phi}_{\textbf{SS}}` and :math:`\bf{\Phi}_{\textbf{NN}}`
are the power spectral density (PSD) matrices of speech and noise, respectively.
:math:`\bf{u}` is a one-hot vector that represents the reference channel.
Args:
psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor of dimension `(..., freq, channel, channel)`
psd_n (Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor of dimension `(..., freq, channel, channel)`
reference_channel (int or Tensor): Specifies the reference channel.
If the type is ``int``, it represents the reference channel index.
If the type is ``Tensor``, its dimension is `(..., channel)`, where the ``channel`` dimension
is one-hot.
diagonal_loading (bool, optional): whether to apply diagonal loading to ``psd_n``
(Default: ``True``)
diag_eps (float, optional): the coefficient of the identity matrix used for diagonal loading,
scaled internally by the trace of ``psd_n`` (Default: ``1e-7``)
eps (float, optional): a value added to the denominator (the trace of the numerator) to avoid division by zero
(Default: ``1e-8``)
Returns:
Tensor: The complex-valued MVDR beamforming weight matrix of dimension `(..., freq, channel)`.
"""
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps, eps=eps)
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_compute_mat_trace(numerator)[..., None, None] + eps)
if torch.jit.isinstance(reference_channel, int):
beamform_weights = ws[..., :, reference_channel]
elif torch.jit.isinstance(reference_channel, Tensor):
reference_channel = reference_channel.to(psd_n.dtype)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_weights = torch.einsum("...c,...c->...", [ws, reference_channel[..., None, None, :]])
else:
raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
return beamform_weights
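
For context, a sketch of the end-to-end usage this enables (illustrative; the shapes, the random inputs, and the final einsum applying the weights as a conjugate inner product over channels are assumptions rather than APIs added here):

import torch
import torchaudio.functional as F

n_channel, n_freq, n_frame = 6, 257, 100
specgram = torch.rand(n_channel, n_freq, n_frame, dtype=torch.cfloat)  # multichannel STFT (..., channel, freq, time)
mask = torch.rand(n_freq, n_frame)  # T-F mask of the target speech (assumed given)
psd_speech = F.psd(specgram, mask)  # (..., freq, channel, channel)
psd_noise = F.psd(specgram, 1 - mask)
weights = F.mvdr_weights_souden(psd_speech, psd_noise, 0)  # (..., freq, channel)
enhanced = torch.einsum("...fc,...cft->...ft", [weights.conj(), specgram])  # (..., freq, time)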