Commit 07bd1aa3 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add psd method to torchaudio.functional (#2227)

Summary:
This PR adds the ``psd`` function to ``torchaudio.functional``.
It computes the power spectral density (PSD) matrix of a multi-channel complex-valued spectrum.
The function also supports optional normalization of the time-frequency mask.

Pull Request resolved: https://github.com/pytorch/audio/pull/2227

Reviewed By: mthrok

Differential Revision: D34473908

Pulled By: nateanl

fbshipit-source-id: c1cfc584085d77881b35d41d76d39b26fca1dda9
parent 34b53ee7
......@@ -238,6 +238,14 @@ treble_biquad
.. autofunction:: spectral_centroid
:hidden:`Multi-channel`
~~~~~~~~~~~~~~~~~~~~~~~
psd
---
.. autofunction:: psd
:hidden:`Loss`
~~~~~~~~~~~~~~
......
import numpy as np
def psd_numpy(specgram, mask=None, normalize=True, eps=1e-10):
    """NumPy reference implementation of the cross-channel PSD matrix.

    Args:
        specgram: Complex-valued array of shape ``(..., channel, freq, time)``.
        mask: Optional real-valued array of shape ``(..., freq, time)`` used to
            weight each time-frequency bin before summing over time.
        normalize: If ``True``, divide ``mask`` by its sum over the time axis
            (plus ``eps``) before applying it.
        eps: Value added to the denominator during mask normalization.

    Returns:
        Complex-valued array of shape ``(..., freq, channel, channel)``.
    """
    # Swap the channel and frequency axes using *relative* indices so that any
    # leading batch dimensions are left untouched: (..., freq, channel, time).
    specgram_transposed = np.swapaxes(specgram, -3, -2)
    # Per-frame outer product between channels:
    # (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
    psd = np.einsum("...ct,...et->...tce", specgram_transposed, specgram_transposed.conj())
    if mask is not None:
        if normalize:
            mask_normalized = mask / (mask.sum(axis=-1, keepdims=True) + eps)
        else:
            mask_normalized = mask
        # Broadcast the (..., freq, time) weights over both channel axes.
        psd = psd * mask_normalized[..., None, None]
    # Collapse the time axis.
    psd = psd.sum(axis=-3)
    return psd
......@@ -250,6 +250,21 @@ class Autograd(TestBaseMixin):
Q = torch.tensor(Q)
self.assert_grad(F.bandreject_biquad, (x, sr, central_freq, Q))
@parameterized.expand([(True,), (False,)])
def test_psd(self, use_mask):
    """Run ``assert_grad`` on ``F.psd``, once with a random mask and once with ``mask=None``."""
    torch.random.manual_seed(2434)
    specgram = torch.rand(4, 10, 5, dtype=torch.cfloat)
    # The mask drops the leading (channel) dimension of the spectrum.
    mask = torch.rand(10, 5) if use_mask else None
    self.assert_grad(F.psd, (specgram, mask))
class AutogradFloat32(TestBaseMixin):
def assert_grad(
......
......@@ -294,3 +294,26 @@ class TestFunctional(common_utils.TorchaudioTestCase):
a = torch.rand(self.batch_size, 3)
b = torch.rand(self.batch_size, 3)
self.assert_batch_consistency(F.filtfilt, inputs=(x, a, b))
def test_psd(self):
    """Check batch consistency of ``F.psd`` when no mask is given."""
    n_batch, n_channel = 2, 3
    n_fft, n_freq = 400, 201
    # Generate one noise channel per (batch, channel) pair, then fold them back
    # into separate batch and channel dimensions after the spectrogram.
    noise = common_utils.get_whitenoise(sample_rate=44100, duration=0.05, n_channels=n_batch * n_channel)
    spectrum = common_utils.get_spectrogram(noise, n_fft=n_fft, hop_length=100)
    n_frame = spectrum.size(-1)
    spectrum = spectrum.view(n_batch, n_channel, n_freq, n_frame)
    self.assert_batch_consistency(F.psd, (spectrum,))
def test_psd_with_mask(self):
    """Check batch consistency of ``F.psd`` when a random mask is supplied."""
    n_batch, n_channel = 2, 3
    n_fft, n_freq = 400, 201
    # Generate one noise channel per (batch, channel) pair, then fold them back
    # into separate batch and channel dimensions after the spectrogram.
    noise = common_utils.get_whitenoise(sample_rate=44100, duration=0.05, n_channels=n_batch * n_channel)
    spectrum = common_utils.get_spectrogram(noise, n_fft=n_fft, hop_length=100)
    n_frame = spectrum.size(-1)
    spectrum = spectrum.view(n_batch, n_channel, n_freq, n_frame)
    # The mask has no channel dimension: (batch, freq, time).
    tf_mask = torch.rand(n_batch, n_freq, n_frame)
    self.assert_batch_consistency(F.psd, (spectrum, tf_mask))
......@@ -14,6 +14,7 @@ from torchaudio_unittest.common_utils import (
nested_params,
get_whitenoise,
rnnt_utils,
beamform_utils,
)
......@@ -582,6 +583,43 @@ class Functional(TestBaseMixin):
ref_costs, ref_gradients = rnnt_utils.compute_with_numpy_transducer(data=data)
self._test_costs_and_gradients(data=data, ref_costs=ref_costs, ref_gradients=ref_gradients)
def test_psd(self):
    """Compare ``F.psd`` against the NumPy reference ``psd_numpy``.

    A random multi-channel complex-valued spectrum is fed to both
    implementations and the two results are asserted equal.
    """
    shape = (4, 10, 5)  # (channel, n_fft_bin, frame)
    real_part = np.random.random(shape)
    imag_part = np.random.random(shape)
    specgram = real_part + imag_part * 1j
    expected = beamform_utils.psd_numpy(specgram)
    actual = F.psd(torch.tensor(specgram, dtype=self.complex_dtype, device=self.device))
    self.assertEqual(torch.tensor(expected, dtype=self.complex_dtype, device=self.device), actual)
@parameterized.expand([(True,), (False,)])
def test_psd_with_mask(self, normalize: bool):
    """Compare masked ``F.psd`` against the NumPy reference ``psd_numpy``.

    A random multi-channel complex-valued spectrum and a random
    single-channel real-valued mask are fed to both implementations,
    with and without mask normalization, and the results are asserted equal.
    """
    n_channel, n_freq, n_frame = 4, 10, 5
    shape = (n_channel, n_freq, n_frame)
    real_part = np.random.random(shape)
    imag_part = np.random.random(shape)
    specgram = real_part + imag_part * 1j
    mask = np.random.random((n_freq, n_frame))
    expected = beamform_utils.psd_numpy(specgram, mask, normalize)
    specgram_tensor = torch.tensor(specgram, dtype=self.complex_dtype, device=self.device)
    mask_tensor = torch.tensor(mask, dtype=self.dtype, device=self.device)
    actual = F.psd(specgram_tensor, mask_tensor, normalize=normalize)
    self.assertEqual(torch.tensor(expected, dtype=self.complex_dtype, device=self.device), actual)
class FunctionalCPUOnly(TestBaseMixin):
def test_melscale_fbanks_no_warning_high_n_freq(self):
......
......@@ -617,6 +617,27 @@ class Functional(TempDirMixin, TestBaseMixin):
)[..., None]
self._assert_consistency_complex(F.phase_vocoder, (tensor, rate, phase_advance))
def test_psd(self):
    """Run the complex-input consistency check on ``F.psd`` with ``mask=None``."""
    shape = (2, 4, 10, 10)  # (batch, channel, n_fft_bin, frame)
    spectrum = torch.rand(*shape, dtype=self.complex_dtype)
    # All arguments are passed positionally: (specgram, mask, normalize, eps).
    args = (spectrum, None, True, 1e-10)
    self._assert_consistency_complex(F.psd, args)
def test_psd_with_mask(self):
    """Run the complex-input consistency check on ``F.psd`` with a random mask."""
    batch_size = 2
    channel = 4
    n_fft_bin = 10
    frame = 10
    normalize = True
    eps = 1e-10
    # Allocate both tensors on ``self.device`` so the spectrum and the mask
    # can never end up on different devices.
    specgram = torch.rand(batch_size, channel, n_fft_bin, frame, dtype=self.complex_dtype, device=self.device)
    mask = torch.rand(batch_size, n_fft_bin, frame, device=self.device)
    self._assert_consistency_complex(F.psd, (specgram, mask, normalize, eps))
class FunctionalFloat32Only(TestBaseMixin):
def test_rnnt_loss(self):
......
......@@ -46,6 +46,7 @@ from .functional import (
edit_distance,
pitch_shift,
rnnt_loss,
psd,
)
__all__ = [
......@@ -94,4 +95,5 @@ __all__ = [
"edit_distance",
"pitch_shift",
"rnnt_loss",
"psd",
]
......@@ -37,6 +37,7 @@ __all__ = [
"edit_distance",
"pitch_shift",
"rnnt_loss",
"psd",
]
......@@ -1631,3 +1632,40 @@ def rnnt_loss(
return costs.sum()
return costs
def psd(
    specgram: Tensor,
    mask: Optional[Tensor] = None,
    normalize: bool = True,
    eps: float = 1e-10,
) -> Tensor:
    """Compute the cross-channel power spectral density (PSD) matrix.

    Args:
        specgram (Tensor): Multi-channel complex-valued spectrum.
            Tensor of dimension `(..., channel, freq, time)`
        mask (Tensor or None, optional): Real-valued time-frequency mask
            used to weight each frame before summing over time.
            Tensor of dimension `(..., freq, time)` (Default: ``None``)
        normalize (bool, optional): whether to normalize the mask along the time dimension. (Default: ``True``)
        eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)

    Returns:
        Tensor: The complex-valued PSD matrix of the input spectrum.
            Tensor of dimension `(..., freq, channel, channel)`
    """
    # Put the frequency axis ahead of the channel axis: (..., freq, channel, time).
    spec_by_freq = specgram.transpose(-3, -2)
    # Per-frame outer product between channels:
    # (..., ch_1, time) x (..., ch_2, time) -> (..., time, ch_1, ch_2)
    outer = torch.einsum("...ct,...et->...tce", [spec_by_freq, spec_by_freq.conj()])

    if mask is not None:
        weights = mask
        if normalize:
            # Scale the mask by its sum over time; ``eps`` guards the division.
            weights = weights / (weights.sum(dim=-1, keepdim=True) + eps)
        # Broadcast the (..., freq, time) weights over both channel axes.
        outer = outer * weights.unsqueeze(-1).unsqueeze(-1)

    # Collapse the time axis.
    return outer.sum(dim=-3)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment