Commit ea74813d authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add rtf_power method to torchaudio.functional (#2231)

Summary:
This PR adds ``rtf_power`` method to ``torchaudio.functional``.
The method computes the relative transfer function (RTF) or the steering vector by [the power iteration method](https://onlinelibrary.wiley.com/doi/abs/10.1002/zamm.19290090206).
[This paper](https://arxiv.org/pdf/2011.15003.pdf) describes the power iteration method in English.
The input arguments are the complex-valued power spectral density (PSD) matrix of the target speech, PSD matrix of noise, int or one-hot Tensor to indicate the reference channel, number of iterations, respectively.

Pull Request resolved: https://github.com/pytorch/audio/pull/2231

Reviewed By: mthrok

Differential Revision: D34474503

Pulled By: nateanl

fbshipit-source-id: 47011427ec4373f808755f0e8eff1efca57655eb
parent 8c1db721
...@@ -261,6 +261,11 @@ rtf_evd ...@@ -261,6 +261,11 @@ rtf_evd
.. autofunction:: rtf_evd .. autofunction:: rtf_evd
rtf_power
---------
.. autofunction:: rtf_power
:hidden:`Loss` :hidden:`Loss`
~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~
......
...@@ -53,3 +53,20 @@ def rtf_evd_numpy(psd): ...@@ -53,3 +53,20 @@ def rtf_evd_numpy(psd):
_, v = np.linalg.eigh(psd) _, v = np.linalg.eigh(psd)
rtf = v[..., -1] rtf = v[..., -1]
return rtf return rtf
def rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter):
    """Numpy reference implementation of the RTF power-iteration method.

    Args:
        psd_s: complex PSD matrix of target speech, shape (..., freq, channel, channel)
        psd_n: complex PSD matrix of noise, shape (..., freq, channel, channel)
        reference_channel: int index, or one-hot array of shape (channel,)
        n_iter: number of power iterations (>= 1)

    Returns:
        Estimated RTF, shape (..., freq, channel)
    """
    # Applying psd_n^{-1} @ psd_s once counts as the first iteration.
    phi = np.linalg.solve(psd_n, psd_s)
    if isinstance(reference_channel, int):
        vec = phi[..., reference_channel]
    else:
        # One-hot reference vector selects a column of phi.
        vec = phi @ reference_channel
    vec = vec[..., np.newaxis]
    if n_iter < 2:
        # A single iteration reduces to psd_n @ phi @ ref == psd_s[..., ref].
        vec = psd_n @ vec
    else:
        # phi above and the final psd_s @ vec count as two iterations.
        for _ in range(n_iter - 2):
            vec = phi @ vec
        vec = psd_s @ vec
    return vec.squeeze(-1)
...@@ -303,6 +303,36 @@ class Autograd(TestBaseMixin): ...@@ -303,6 +303,36 @@ class Autograd(TestBaseMixin):
reference_channel[..., 0].fill_(1) reference_channel[..., 0].fill_(1)
self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel)) self.assert_grad(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power(self, n_iter):
    """Gradcheck ``F.rtf_power`` with an integer reference channel."""
    torch.random.manual_seed(2434)
    num_channels, num_bins = 4, 5
    shape = (num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=torch.cfloat)
    psd_noise = torch.rand(shape, dtype=torch.cfloat)
    self.assert_grad(F.rtf_power, (psd_speech, psd_noise, 0, n_iter))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power_with_tensor(self, n_iter):
    """Gradcheck ``F.rtf_power`` with a one-hot reference-channel Tensor."""
    torch.random.manual_seed(2434)
    num_channels, num_bins = 4, 5
    shape = (num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=torch.cfloat)
    psd_noise = torch.rand(shape, dtype=torch.cfloat)
    one_hot = torch.zeros(num_channels)
    one_hot[0] = 1
    self.assert_grad(F.rtf_power, (psd_speech, psd_noise, one_hot, n_iter))
class AutogradFloat32(TestBaseMixin): class AutogradFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
......
...@@ -374,3 +374,44 @@ class TestFunctional(common_utils.TorchaudioTestCase): ...@@ -374,3 +374,44 @@ class TestFunctional(common_utils.TorchaudioTestCase):
spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat) spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
psd = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj()) psd = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj())
self.assert_batch_consistency(F.rtf_evd, (psd,)) self.assert_batch_consistency(F.rtf_evd, (psd,))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power(self, n_iter):
    """Batch consistency of ``F.rtf_power`` with an integer reference channel."""
    torch.random.manual_seed(2434)
    batch_size, num_bins, num_channels = 2, 10, 4
    shape = (batch_size, num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=torch.cfloat)
    psd_noise = torch.rand(shape, dtype=torch.cfloat)
    func = partial(F.rtf_power, reference_channel=0, n_iter=n_iter)
    self.assert_batch_consistency(func, (psd_speech, psd_noise))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power_with_tensor(self, n_iter):
    """Batch consistency of ``F.rtf_power`` with a one-hot reference-channel Tensor."""
    torch.random.manual_seed(2434)
    batch_size, num_bins, num_channels = 2, 10, 4
    shape = (batch_size, num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=torch.cfloat)
    psd_noise = torch.rand(shape, dtype=torch.cfloat)
    one_hot = torch.zeros(batch_size, num_channels)
    one_hot[..., 0] = 1
    func = partial(F.rtf_power, n_iter=n_iter)
    self.assert_batch_consistency(func, (psd_speech, psd_noise, one_hot))
...@@ -738,6 +738,63 @@ class Functional(TestBaseMixin): ...@@ -738,6 +738,63 @@ class Functional(TestBaseMixin):
rtf_audio = F.rtf_evd(torch.tensor(psd, dtype=self.complex_dtype, device=self.device)) rtf_audio = F.rtf_evd(torch.tensor(psd, dtype=self.complex_dtype, device=self.device))
self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio) self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)
@parameterized.expand(
    [
        (1,),
        (2,),
        (3,),
    ]
)
def test_rtf_power(self, n_iter):
    """Verify ``F.rtf_power`` against the numpy reference implementation.

    Given the PSD matrices of target speech and noise (Tensor of dimension
    `(..., freq, channel, channel)`), an integer reference channel, and the
    number of iterations, ``F.rtf_power`` outputs the relative transfer
    function (RTF) (Tensor of dimension `(..., freq, channel)`), which should
    be identical to the output of ``rtf_power_numpy``.
    """
    num_bins, num_channels = 10, 4
    reference_channel = 0
    shape = (num_bins, num_channels, num_channels)
    psd_s = np.random.random(shape) + np.random.random(shape) * 1j
    psd_n = np.random.random(shape) + np.random.random(shape) * 1j
    expected = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
    actual = F.rtf_power(
        torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
        torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
        reference_channel,
        n_iter,
    )
    self.assertEqual(torch.tensor(expected, dtype=self.complex_dtype, device=self.device), actual)
@parameterized.expand(
    [
        (1,),
        (2,),
        (3,),
    ]
)
def test_rtf_power_with_tensor(self, n_iter):
    """Verify ``F.rtf_power`` against the numpy reference implementation.

    Given the PSD matrices of target speech and noise (Tensor of dimension
    `(..., freq, channel, channel)`), a one-hot Tensor indicating the
    reference channel, and the number of iterations, ``F.rtf_power`` outputs
    the relative transfer function (RTF) (Tensor of dimension
    `(..., freq, channel)`), which should be identical to the output of
    ``rtf_power_numpy``.
    """
    num_bins, num_channels = 10, 4
    reference_channel = np.zeros(num_channels)
    reference_channel[0] = 1
    shape = (num_bins, num_channels, num_channels)
    psd_s = np.random.random(shape) + np.random.random(shape) * 1j
    psd_n = np.random.random(shape) + np.random.random(shape) * 1j
    expected = beamform_utils.rtf_power_numpy(psd_s, psd_n, reference_channel, n_iter)
    actual = F.rtf_power(
        torch.tensor(psd_s, dtype=self.complex_dtype, device=self.device),
        torch.tensor(psd_n, dtype=self.complex_dtype, device=self.device),
        torch.tensor(reference_channel, dtype=self.dtype, device=self.device),
        n_iter,
    )
    self.assertEqual(torch.tensor(expected, dtype=self.complex_dtype, device=self.device), actual)
class FunctionalCPUOnly(TestBaseMixin): class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self): def test_melscale_fbanks_no_warning_high_n_freq(self):
......
...@@ -698,6 +698,35 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -698,6 +698,35 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype) tensor = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
self._assert_consistency_complex(F.rtf_evd, (tensor,)) self._assert_consistency_complex(F.rtf_evd, (tensor,))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power(self, n_iter):
    """TorchScript consistency of ``F.rtf_power`` with an integer reference channel."""
    num_channels, num_bins = 4, 10
    shape = (num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=self.complex_dtype)
    psd_noise = torch.rand(shape, dtype=self.complex_dtype)
    reference_channel = 0
    self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, reference_channel, n_iter))
@parameterized.expand(
    [
        (1,),
        (3,),
    ]
)
def test_rtf_power_with_tensor(self, n_iter):
    """TorchScript consistency of ``F.rtf_power`` with a one-hot reference-channel Tensor."""
    num_channels, num_bins = 4, 10
    shape = (num_bins, num_channels, num_channels)
    psd_speech = torch.rand(shape, dtype=self.complex_dtype)
    psd_noise = torch.rand(shape, dtype=self.complex_dtype)
    one_hot = torch.zeros(num_channels)
    one_hot[0] = 1
    self._assert_consistency_complex(F.rtf_power, (psd_speech, psd_noise, one_hot, n_iter))
class FunctionalFloat32Only(TestBaseMixin): class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
......
...@@ -50,6 +50,7 @@ from .functional import ( ...@@ -50,6 +50,7 @@ from .functional import (
mvdr_weights_souden, mvdr_weights_souden,
mvdr_weights_rtf, mvdr_weights_rtf,
rtf_evd, rtf_evd,
rtf_power,
) )
__all__ = [ __all__ = [
...@@ -102,4 +103,5 @@ __all__ = [ ...@@ -102,4 +103,5 @@ __all__ = [
"mvdr_weights_souden", "mvdr_weights_souden",
"mvdr_weights_rtf", "mvdr_weights_rtf",
"rtf_evd", "rtf_evd",
"rtf_power",
] ]
...@@ -41,6 +41,7 @@ __all__ = [ ...@@ -41,6 +41,7 @@ __all__ = [
"mvdr_weights_souden", "mvdr_weights_souden",
"mvdr_weights_rtf", "mvdr_weights_rtf",
"rtf_evd", "rtf_evd",
"rtf_power",
] ]
...@@ -1842,3 +1843,46 @@ def rtf_evd(psd_s: Tensor) -> Tensor: ...@@ -1842,3 +1843,46 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
_, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order _, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order
rtf = v[..., -1] # choose the eigenvector with max eigenvalue rtf = v[..., -1] # choose the eigenvector with max eigenvalue
return rtf return rtf
def rtf_power(psd_s: Tensor, psd_n: Tensor, reference_channel: Union[int, Tensor], n_iter: int = 3) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by the power method.

    Args:
        psd_s (Tensor): The complex-valued covariance matrix of target speech.
            Tensor of dimension `(..., freq, channel, channel)`
        psd_n (Tensor): The complex-valued covariance matrix of noise.
            Tensor of dimension `(..., freq, channel, channel)`
        reference_channel (int or Tensor): Indicates the reference channel.
            If the dtype is ``int``, it represents the reference channel index.
            If the dtype is ``Tensor``, the dimension is `(..., channel)`, where the ``channel``
            dimension is one-hot.
        n_iter (int): number of iterations in the power method. (Default: ``3``)

    Returns:
        Tensor: the estimated complex-valued RTF of target speech
        Tensor of dimension `(..., freq, channel)`

    Raises:
        ValueError: If ``n_iter`` is not a positive integer.
        TypeError: If ``reference_channel`` is neither an ``int`` nor a ``Tensor``.
    """
    # Validate with an explicit raise instead of ``assert`` so the check
    # survives python -O (asserts are stripped under optimization).
    if n_iter <= 0:
        raise ValueError("The number of iterations must be greater than 0.")
    # phi is regarded as the first iteration
    phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
    if torch.jit.isinstance(reference_channel, int):
        rtf = phi[..., reference_channel]
    elif torch.jit.isinstance(reference_channel, Tensor):
        reference_channel = reference_channel.to(psd_n.dtype)
        # The one-hot vector selects the reference column of phi.
        # Operands are passed directly (the list form of einsum is deprecated).
        rtf = torch.einsum("...c,...c->...", phi, reference_channel[..., None, None, :])
    else:
        raise TypeError(f"Expected 'int' or 'Tensor' for reference_channel. Found: {type(reference_channel)}.")
    rtf = rtf.unsqueeze(-1)  # (..., freq, channel, 1)
    if n_iter >= 2:
        # The number of iterations in the for loop is `n_iter - 2`
        # because the `phi` above and `torch.matmul(psd_s, rtf)` are regarded as
        # two iterations.
        for _ in range(n_iter - 2):
            rtf = torch.matmul(phi, rtf)
        rtf = torch.matmul(psd_s, rtf)
    else:
        # If there is only one iteration, the rtf is psd_s[..., reference_channel],
        # which is psd_n @ phi @ ref_channel.
        rtf = torch.matmul(psd_n, rtf)
    return rtf.squeeze(-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment