"git@developer.sourcefind.cn:OpenDAS/dcu_env_check.git" did not exist on "dd24dfc04b8eecb6b67a949ca99d99fbdd6d1af3"
Commit da1e83cc authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add diagonal_loading option to rtf_power (#2369)

Summary:
When computing the MVDR beamforming weights with the power iteration method, diagonal loading can be applied to the noise PSD matrix to improve robustness. The same regularization is applicable when computing the RTF matrix (see https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py#L614 for an example). It also keeps `torchaudio.functional.rtf_power` consistent with the existing `torchaudio.transforms.MVDR` module.

This PR adds a `diagonal_loading` argument, with a default value of `True`, to `torchaudio.functional.rtf_power`.
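
As a minimal sketch (not the library implementation), the regularization amounts to adding a small, trace-scaled multiple of the identity to the noise PSD before solving; the helper name `diagonal_load` and the constants `diag_eps`/`eps` below simply mirror the NumPy reference implementation added in the tests:

```python
import numpy as np

def diagonal_load(psd_n, diag_eps=1e-7, eps=1e-8):
    """Add a trace-scaled multiple of the identity to a batch of noise PSD
    matrices of shape (..., channel, channel) to keep them well conditioned."""
    channel = psd_n.shape[-1]
    eye = np.eye(channel)
    # Per-matrix trace over the last two axes; its real part sets the scale.
    trace = np.trace(psd_n, axis1=-2, axis2=-1)
    epsilon = trace.real[..., None, None] * diag_eps + eps
    return psd_n + epsilon * eye
```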

Pull Request resolved: https://github.com/pytorch/audio/pull/2369

Reviewed By: carolineechen

Differential Revision: D36204130

Pulled By: nateanl

fbshipit-source-id: 93a58d5c2107841a16c4e32f0c16ab0d6b2d9420
parent aed5eb88
@@ -55,7 +55,14 @@ def rtf_evd_numpy(psd):
     return rtf


-def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
+def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading=True, diag_eps=1e-7, eps=1e-8):
+    if diagonal_loading:
+        channel = psd_s.shape[-1]
+        eye = np.eye(channel)
+        trace = np.matrix.trace(psd_n, axis1=1, axis2=2)
+        epsilon = trace.real[..., None, None] * diag_eps + eps
+        diag = epsilon * eye[..., :, :]
+        psd_n = psd_n + diag
     phi = np.linalg.solve(psd_n, psd_s)
     if isinstance(reference_channel, int):
         rtf = phi[..., reference_channel]
...
@@ -333,25 +333,25 @@ class Autograd(TestBaseMixin):
     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         torch.random.manual_seed(2434)
         channel = 4
         n_fft_bin = 5
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
-        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter))
+        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter, diagonal_loading))

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         torch.random.manual_seed(2434)
         channel = 4
         n_fft_bin = 5
@@ -359,7 +359,7 @@ class Autograd(TestBaseMixin):
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
         reference_channel = torch.zeros(channel)
         reference_channel[0].fill_(1)
-        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        self.assert_grad(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading))

     def test_apply_beamforming(self):
         torch.random.manual_seed(2434)
...
@@ -740,12 +740,12 @@ class Functional(TestBaseMixin):
     @parameterized.expand(
         [
-            (1,),
-            (2,),
-            (3,),
+            (1, True),
+            (2, False),
+            (3, True),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         """Verify ``F.rtf_power`` method by numpy implementation.
         Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
         an integer indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
@@ -757,23 +757,24 @@ class Functional(TestBaseMixin):
         reference_channel = 0
         psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
         psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
-        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
+        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
         rtf_audio = F.rtf_power(
             torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
             torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
             reference_channel,
             n_iter,
+            diagonal_loading=diagonal_loading,
         )
         self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)

     @parameterized.expand(
         [
-            (1,),
-            (2,),
-            (3,),
+            (1, True),
+            (2, False),
+            (3, True),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         """Verify ``F.rtf_power`` method by numpy implementation.
         Given the PSD matrices of target speech and noise (Tensor of dimension `(..., freq, channel, channel`)
         a one-hot Tensor indicating the reference channel, and an integer for number of iterations, ``F.rtf_power``
@@ -786,12 +787,13 @@ class Functional(TestBaseMixin):
         reference_channel[0] = 1
         psd_s = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
         psd_n = np.random.random((n_fft_bin, channel, channel)) + np.random.random((n_fft_bin, channel, channel)) * 1j
-        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
+        rtf = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter, diagonal_loading)
         rtf_audio = F.rtf_power(
             torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
             torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
             torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
             n_iter,
+            diagonal_loading=diagonal_loading,
         )
         self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)
...
@@ -700,32 +700,38 @@ class Functional(TempDirMixin, TestBaseMixin):
     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power(self, n_iter):
+    def test_rtf_power(self, n_iter, diagonal_loading):
         channel = 4
         n_fft_bin = 10
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         reference_channel = 0
-        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        diag_eps = 1e-7
+        self._assert_consistency_complex(
+            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
+        )

     @parameterized.expand(
         [
-            (1,),
-            (3,),
+            (1, True),
+            (3, False),
         ]
     )
-    def test_rtf_power_with_tensor(self, n_iter):
+    def test_rtf_power_with_tensor(self, n_iter, diagonal_loading):
         channel = 4
         n_fft_bin = 10
         psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=self.complex_dtype)
         reference_channel = torch.zeros(channel)
         reference_channel[..., 0].fill_(1)
-        self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
+        diag_eps = 1e-7
+        self._assert_consistency_complex(
+            F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter, diagonal_loading, diag_eps)
+        )

     def test_apply_beamforming(self):
         num_channels = 4
...
@@ -1961,7 +1961,14 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
     return rtf


-def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
+def rtf_power(
+    psd_s: Tensor,
+    psd_n: Tensor,
+    reference_channel: Union[int, Tensor],
+    n_iter: int = 3,
+    diagonal_loading: bool = True,
+    diag_eps: float = 1e-7,
+) -> Tensor:
     r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

     .. devices:: CPU CUDA
@@ -1969,21 +1976,27 @@ def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor
     .. properties:: Autograd TorchScript

     Args:
-        psd_s (Tensor): The complex-valued covariance matrix of target speech.
-            Tensor of dimension `(..., freq, channel, channel)`
-        psd_n (Tensor): The complex-valued covariance matrix of noise.
-            Tensor of dimension `(..., freq, channel, channel)`
-        reference_channel (int or Tensor): Indicate the reference channel.
-            If the dtype is ``int``, it represent the reference channel index.
-            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel`` dimension
+        psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
+            Tensor with dimensions `(..., freq, channel, channel)`.
+        reference_channel (int or torch.Tensor): Specifies the reference channel.
+            If the dtype is ``int``, it represents the reference channel index.
+            If the dtype is ``torch.Tensor``, its shape is `(..., channel)`, where the ``channel`` dimension
             is one-hot.
-        n_iter (int): number of iterations in power method. (Default: ``3``)
+        diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
+            (Default: ``True``)
+        diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
+            It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)

     Returns:
-        Tensor: the estimated complex-valued RTF of target speech
-            Tensor of dimension `(..., freq, channel)`
+        torch.Tensor: The estimated complex-valued RTF of target speech.
+            Tensor of dimension `(..., freq, channel)`.
     """
     assert n_iter > 0, "The number of iteration must be greater than 0."
+    # Apply diagonal loading to psd_n to improve robustness.
+    if diagonal_loading:
+        psd_n = _tik_reg(psd_n, reg=diag_eps)

     # phi is regarded as the first iteration
     phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
     if torch.jit.isinstance(reference_channel, int):
...
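
For reference, a small, hypothetical usage sketch of the updated functional (random matrices stand in for PSD estimates computed from real multichannel recordings, and a torchaudio build containing this change is assumed):

```python
import torch
import torchaudio.functional as F

torch.random.manual_seed(0)
n_fft_bin, channel = 257, 4
psd_speech = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)
psd_noise = torch.rand(n_fft_bin, channel, channel, dtype=torch.cfloat)

# Diagonal loading of psd_noise is enabled by default; pass
# diagonal_loading=False to recover the previous behavior.
rtf = F.rtf_power(psd_speech, psd_noise, reference_channel=0, n_iter=3, diagonal_loading=True)
print(rtf.shape)  # torch.Size([257, 4])
```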