Commit da1e83cc authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add diagonal_loading option to rtf_power (#2369)

Summary:
When computing the MVDR beamforming weights with the power-iteration method, diagonal loading can be applied to the noise PSD matrix to improve robustness. The same applies when computing the RTF matrix (see https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py#L614 for an example), and it keeps `torchaudio.functional.rtf_power` consistent with the existing `torchaudio.transforms.MVDR` module.
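For reference, here is a minimal sketch of what diagonal loading (Tikhonov regularization) does to a batch of noise PSD matrices. It mirrors the trace-scaled regularization added to the numpy test helper in this diff; the function name, defaults, and standalone-helper form are illustrative only (internally, `torchaudio` applies this via its `_tik_reg` helper, as shown in the `rtf_power` hunk below):

```python
import torch


def load_diagonal(psd_n: torch.Tensor, diag_eps: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
    """Add a trace-scaled identity to PSD matrices of shape `(..., channel, channel)`."""
    channel = psd_n.size(-1)
    eye = torch.eye(channel, dtype=psd_n.dtype, device=psd_n.device)
    # Scale the loading by the trace of each matrix so the regularization is
    # proportional to the energy of the PSD matrix, plus a small absolute floor.
    trace = torch.diagonal(psd_n, dim1=-2, dim2=-1).sum(-1)
    epsilon = trace.real[..., None, None] * diag_eps + eps
    return psd_n + epsilon * eye
```

Because only a small positive value is added to the diagonal, the matrix is left nearly unchanged while `torch.linalg.solve(psd_n, psd_s)` stays well conditioned when the noise PSD estimate is close to singular.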

This PR adds a `diagonal_loading` argument, defaulting to `True`, to `torchaudio.functional.rtf_power`.
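A hedged usage sketch of the new argument (assuming a torchaudio build that includes this change, and using random complex matrices as stand-ins for real PSD estimates):

```python
import torch
import torchaudio.functional as F

n_fft_bin, channel = 257, 4
# Stand-ins for speech/noise PSD matrices of shape (..., freq, channel, channel).
psd_s = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_n = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)

# Default behavior: diagonal loading is applied to psd_n with diag_eps=1e-7.
rtf = F.rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3)
# Opt out to recover the previous behavior.
rtf_raw = F.rtf_power(psd_s, psd_n, reference_channel=0, n_iter=3, diagonal_loading=False)
print(rtf.shape)  # torch.Size([257, 4]), i.e. (..., freq, channel)
```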

Pull Request resolved: https://github.com/pytorch/audio/pull/2369

Reviewed By: carolineechen

Differential Revision: D36204130

Pulled By: nateanl

fbshipit-source-id: 93a58d5c2107841a16c4e32f0c16ab0d6b2d9420
parent aed5eb88
@@ -55,7 +55,14 @@ def rtf_evd_numpy(psd):
     return rtf


-def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
+def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading=True, diag_eps=1e-7, eps=1e-8):
+    if diagonal_loading:
+        channel = psd_s.shape[-1]
+        eye = np.eye(channel)
+        trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
+        epsilon = trace.real[..., None, None] * diag_eps + eps
+        diag = epsilon * eye[..., :, :]
+        psd_n = psd_n + diag
     phi = np.linalg.solve(psd_n, psd_s)
     if isinstance(reference_channel, int):
         rtf = phi[..., reference_channel]
@@ -333,25 +333,25 @@ class Autograd(TestBaseMixin):

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         torch.random.manual_seed(2434)
         channel = 4
         n_fft_bin = 5
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
-        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter))
+        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         torch.random.manual_seed(2434)
         channel = 4
         n_fft_bin = 5
@@ -359,7 +359,7 @@ class Autograd(TestBaseMixin):
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
         reference_channel = torch.zeros(channel)
         reference_channel[0].fill_(1)
-        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))

     def test_apply_beamforming(self):
         torch.random.manual_seed(2434)
@@ -740,12 +740,12 @@ class Functional(TestBaseMixin):

     @parameterized.expand(
         [
-            (1,),
-            (2,),
-            (3,),
+            (1, True),
+            (2, False),
+            (3, True),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         """Verify ``F.rtf_power`` method by numpy implementation.
         Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
         an integer indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
@@ -757,23 +757,24 @@ class Functional(TestBaseMixin):
         reference_channel = 0
         psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
         psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
-        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
+        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
         rtf_audio = F.rtf_power(
             torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
             torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
             reference_channel,
             n_iter,
+            diagonal_loading=diagonal_loading,
         )
         self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

     @parameterized.expand(
         [
-            (1,),
-            (2,),
-            (3,),
+            (1, True),
+            (2, False),
+            (3, True),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         """Verify ``F.rtf_power`` method by numpy implementation.
         Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
         a one-hot Tensor indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
@@ -786,12 +787,13 @@ class Functional(TestBaseMixin):
         reference_channel[0] = 1
         psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
         psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
-        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
+        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
         rtf_audio = F.rtf_power(
             torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
             torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
             torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
             n_iter,
+            diagonal_loading=diagonal_loading,
         )
         self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)
@@ -700,32 +700,38 @@ class Functional(TempDirMixin, TestBaseMixin):

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         channel = 4
         n_fft_bin = 10
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         reference_channel = 0
-        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        diag_eps = 1e-7
+        self._assert_consistency_complex(
+            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
+        )

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         channel = 4
         n_fft_bin = 10
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         reference_channel = torch.zeros(channel)
         reference_channel[..., 0].fill_(1)
-        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        diag_eps = 1e-7
+        self._assert_consistency_complex(
+            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
+        )

     def test_apply_beamforming(self):
         num_channels = 4
@@ -1961,7 +1961,14 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
     return rtf


-def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
+def rtf_power(
+    psd_s: Tensor,
+    psd_n: Tensor,
+    reference_channel: Union[int, Tensor],
+    n_iter: int = 3,
+    diagonal_loading: bool = True,
+    diag_eps: float = 1e-7,
+) -> Tensor:
     r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

     .. devices:: CPU CUDA
@@ -1969,21 +1976,27 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor
     .. properties:: Autograd TorchScript

     Args:
-        psd_s (Tensor): The complex-valued covariance matrix of target speech.
-            Tensor of dimension `(..., freq, channel, channel)`
-        psd_n (Tensor): The complex-valued covariance matrix of noise.
-            Tensor of dimension `(..., freq, channel, channel)`
-        reference_channel (int or Tensor): Indicate the reference channel.
-            If the dtype is ``int``, it represent the reference channel index.
-            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
+        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_channel (int or torch.Tensor): Specifies the reference channel.
+            If the dtype is ``int``, it represents the reference channel index.
+            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
             is one-hot.
         n_iter (int): number of iterations in power method. (Default: ``3``)
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+            (Default: ``True``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)

     Returns:
-        Tensor: the estimated complex-valued RTF of target speech
-            Tensor of dimension `(..., freq, channel)`
+        torch.Tensor: The estimated complex-valued RTF of target speech.
+            Tensor of dimension `(..., freq, channel)`.
     """
     assert n_iter > 0, "The number of iteration must be greater than 0."
+    # Apply diagonal loading to psd_n to improve robustness.
+    if diagonal_loading:
+        psd_n = _tik_reg(psd_n, reg=diag_eps)
     # phi is regarded as the first iteration
     phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
     if torch.jit.isinstance(reference_channel, int):