Commit 86fe4fa7 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Add rtf_evd method to torchaudio.functional (#2230)

Summary:
This PR adds the `rtf_evd` method to `torchaudio.functional`.
The method estimates the relative transfer function (RTF), also known as the steering vector, by eigenvalue decomposition.
Its input is the complex-valued power spectral density (PSD) matrix of the target speech.
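
For reference, a minimal usage sketch (not part of the diff; the toy PSD construction and shapes below simply mirror the new tests):

```python
import torch
import torchaudio.functional as F

# Build a toy PSD matrix from a random multi-channel spectrum, then estimate
# the RTF / steering vector of the target source from its principal eigenvector.
batch_size, n_fft_bin, channel = 2, 129, 4
spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
psd_s = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj())  # (..., freq, channel, channel)
rtf = F.rtf_evd(psd_s)  # (..., freq, channel)
```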

Pull Request resolved: https://github.com/pytorch/audio/pull/2230

Reviewed By: mthrok

Differential Revision: D34474188

Pulled By: nateanl

fbshipit-source-id: 888df4b187608ed3c2b7271b34d2231cdabb0134
parent 3566ffc5
@@ -256,6 +256,11 @@ mvdr_weights_rtf
.. autofunction:: mvdr_weights_rtf

rtf_evd
-------

.. autofunction:: rtf_evd

:hidden:`Loss`
~~~~~~~~~~~~~~
@@ -47,3 +47,9 @@ def mvdr_weights_rtf_numpy(rtf, psd_n, reference_channel, diag_eps=1e-7, eps=1e-
    scale = np.einsum("...c,...c->...", rtf.conj(), reference_channel[..., None, :])
    beamform_weights = beamform_weights * scale[..., None]
    return beamform_weights


def rtf_evd_numpy(psd):
    # Reference implementation: the RTF is the eigenvector of the PSD matrix with the
    # largest eigenvalue (np.linalg.eigh returns eigenpairs in ascending eigenvalue order).
    _, v = np.linalg.eigh(psd)
    rtf = v[..., -1]
    return rtf
@@ -365,3 +365,12 @@ class TestFunctional(common_utils.TorchaudioTestCase):
        reference_channel = torch.zeros(batch_size, channel)
        reference_channel[..., 0].fill_(1)
        self.assert_batch_consistency(F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel))

    def test_rtf_evd(self):
        torch.random.manual_seed(2434)
        batch_size = 2
        channel = 4
        n_fft_bin = 5
        spectrum = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
        psd = torch.einsum("...c,...d->...cd", spectrum, spectrum.conj())
        self.assert_batch_consistency(F.rtf_evd, (psd,))
@@ -724,6 +724,20 @@ class Functional(TestBaseMixin):
            rtol=1e-6,
        )

    def test_rtf_evd(self):
        """Verify ``F.rtf_evd`` against the numpy implementation ``rtf_evd_numpy``.

        Given a multi-channel complex-valued spectrum, we compute the PSD matrix and pass it to
        ``F.rtf_evd``, which returns the relative transfer function (RTF) as a Tensor of dimension
        `(..., freq, channel)`. The result should match the output of ``rtf_evd_numpy``.
        """
        n_fft_bin = 10
        channel = 4
        specgram = np.random.random((n_fft_bin, channel)) + np.random.random((n_fft_bin, channel)) * 1j
        psd = np.einsum("fc,fd->fcd", specgram.conj(), specgram)
        rtf = beamform_utils.rtf_evd_numpy(psd)
        rtf_audio = F.rtf_evd(torch.tensor(psd, dtype=self.complex_dtype, device=self.device))
        self.assertEqual(torch.tensor(rtf, dtype=self.complex_dtype, device=self.device), rtf_audio)


class FunctionalCPUOnly(TestBaseMixin):
    def test_melscale_fbanks_no_warning_high_n_freq(self):
@@ -691,6 +691,13 @@ class Functional(TempDirMixin, TestBaseMixin):
            F.mvdr_weights_rtf, (rtf, psd_noise, reference_channel, diagonal_loading, diag_eps, eps)
        )

    def test_rtf_evd(self):
        batch_size = 2
        channel = 4
        n_fft_bin = 129
        tensor = torch.rand(batch_size, n_fft_bin, channel, channel, dtype=self.complex_dtype)
        self._assert_consistency_complex(F.rtf_evd, (tensor,))


class FunctionalFloat32Only(TestBaseMixin):
    def test_rnnt_loss(self):
@@ -49,6 +49,7 @@ from .functional import (
    psd,
    mvdr_weights_souden,
    mvdr_weights_rtf,
    rtf_evd,
)
__all__ = [
@@ -100,4 +101,5 @@ __all__ = [
    "psd",
    "mvdr_weights_souden",
    "mvdr_weights_rtf",
    "rtf_evd",
]
@@ -40,6 +40,7 @@ __all__ = [
    "psd",
    "mvdr_weights_souden",
    "mvdr_weights_rtf",
    "rtf_evd",
]
@@ -1825,3 +1826,19 @@ def mvdr_weights_rtf(
    beamform_weights = beamform_weights * scale[..., None]
    return beamform_weights


def rtf_evd(psd_s: Tensor) -> Tensor:
    r"""Estimate the relative transfer function (RTF) or the steering vector by eigenvalue decomposition.

    Args:
        psd_s (Tensor): The complex-valued power spectral density (PSD) matrix of the target speech.
            Tensor of dimension `(..., freq, channel, channel)`.

    Returns:
        Tensor: The estimated complex-valued RTF of the target speech.
            Tensor of dimension `(..., freq, channel)`.
    """
    _, v = torch.linalg.eigh(psd_s)  # eigenvectors in v are sorted in ascending order of their eigenvalues
    rtf = v[..., -1]  # take the eigenvector associated with the largest eigenvalue
    return rtf
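
As an illustration of how the new function fits into the existing beamforming utilities (a sketch, not part of the diff; the synthetic rank-1 PSD matrices and shapes are only for demonstration), the estimated RTF can be passed to `mvdr_weights_rtf` together with a noise PSD matrix and a one-hot reference channel, as in the tests above:

```python
import torch
import torchaudio.functional as F

batch_size, n_fft_bin, channel = 2, 129, 4

# Synthetic target-speech and noise PSD matrices, for illustration only.
spec_s = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
spec_n = torch.rand(batch_size, n_fft_bin, channel, dtype=torch.cfloat)
psd_s = torch.einsum("...c,...d->...cd", spec_s, spec_s.conj())
psd_n = torch.einsum("...c,...d->...cd", spec_n, spec_n.conj())

rtf = F.rtf_evd(psd_s)  # (..., freq, channel)

# One-hot reference channel, as in the batch-consistency test.
reference_channel = torch.zeros(batch_size, channel)
reference_channel[..., 0].fill_(1)

# MVDR beamforming weights derived from the RTF and the noise PSD.
weights = F.mvdr_weights_rtf(rtf, psd_n, reference_channel)  # (..., freq, channel)
```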