Commit 38e530d7 authored by Zhaoheng Ni's avatar Zhaoheng Ni Committed by Facebook GitHub Bot
Browse files

Add assertion checks to multi-channel functions (#2401)

Summary:
- The multi-channel functions only support complex-valued tensors for spectrogram and PSD matrices.
- The mask can be real-valued or complex-valued, hence there is no explicit assertion for mask.
- The shapes of the input Tensors need to be verified before the computation. For example, the shape of the PSD matrix must be `(..., freq, channel, channel)`, the shape of the mask must be `(..., freq, time)`, etc.
- The autograd unit test of `apply_beamforming` used the wrong dimensions for `beamform_weights`, which the new assertion check detected. Fix it in this PR.

Pull Request resolved: https://github.com/pytorch/audio/pull/2401

Reviewed By: carolineechen

Differential Revision: D36597689

Pulled By: nateanl

fbshipit-source-id: 6ad1adebe3726851cc1d865650bdf177a98985f6
parent af9cab3b
......@@ -370,7 +370,7 @@ class Autograd(TestBaseMixin):
x = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=batch_size * num_channels)
specgram = get_spectrogram(x, n_fft=n_fft, hop_length=100)
specgram = specgram.view(batch_size, num_channels, n_fft_bin, specgram.size(-1))
beamform_weights = torch.rand(n_fft_bin, num_channels, dtype=torch.cfloat)
beamform_weights = torch.rand(batch_size, n_fft_bin, num_channels, dtype=torch.cfloat)
self.assert_grad(F.apply_beamforming, (beamform_weights, specgram))
......
......@@ -1751,9 +1751,7 @@ def psd(
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``)
Tensor with dimensions `(..., freq, time)`. (Default: ``None``)
normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
......@@ -1767,6 +1765,10 @@ def psd(
psd = torch.einsum("...ct,...et->...tce", [specgram, specgram.conj()])
if mask is not None:
assert (
mask.shape[:-1] == specgram.shape[:-2] and mask.shape[-1] == specgram.shape[-1]
), "The dimensions of mask except the channel dimension should be the same as specgram."
f"Found {mask.shape} for mask and {specgram.shape} for specgram."
# Normalized mask along time dimension:
if normalize:
mask = mask / (mask.sum(dim=-1, keepdim=True) + eps)
......@@ -1817,6 +1819,31 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T
return mat
def _assert_psd_matrices(psd_s: torch.Tensor, psd_n: torch.Tensor) -> None:
"""Assertion checks of the PSD matrices of target speech and noise.
Args:
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
"""
assert (
psd_s.ndim >= 3 and psd_n.ndim >= 3
), "Expected at least 3D Tensor (..., freq, channel, channel) for psd_s and psd_n."
"Found {psd_s.shape} for psd_s and {psd_n.shape} for psd_n."
assert (
psd_s.is_complex() and psd_n.is_complex()
), "The type of psd_s and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {psd_s.dtype} for psd_s and {psd_n.dtype} for psd_n."
assert (
psd_s.shape == psd_n.shape
), f"The dimensions of psd_s and psd_n should be the same. Found {psd_s.shape} and {psd_n.shape}."
assert (
psd_s.shape[-1] == psd_s.shape[-2]
), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
def mvdr_weights_souden(
psd_s: Tensor,
psd_n: Tensor,
......@@ -1861,6 +1888,8 @@ def mvdr_weights_souden(
Returns:
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
_assert_psd_matrices(psd_s, psd_n)
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps)
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
......@@ -1924,6 +1953,20 @@ def mvdr_weights_rtf(
Returns:
torch.Tensor: The complex-valued MVDR beamforming weight matrix with dimensions `(..., freq, channel)`.
"""
assert rtf.ndim >= 2, f"Expected at least 2D Tensor (..., freq, channel) for rtf. Found {rtf.shape}."
assert psd_n.ndim >= 3, f"Expected at least 3D Tensor (..., freq, channel, channel) for psd_n. Found {psd_n.shape}."
assert (
rtf.is_complex() and psd_n.is_complex()
), "The type of rtf and psd_n must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {rtf.dtype} for rtf and {psd_n.dtype} for psd_n."
assert (
rtf.shape == psd_n.shape[:-1]
), "The dimensions of rtf and the dimensions withou the last dimension of psd_n should be the same."
f"Found {rtf.shape} for rtf and {psd_n.shape} for psd_n."
assert (
psd_n.shape[-1] == psd_n.shape[-2]
), f"The last two dimensions of psd_n should be the same. Found {psd_n.shape}."
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps)
# numerator = psd_n.inv() @ stv
......@@ -1961,6 +2004,10 @@ def rtf_evd(psd_s: Tensor) -> Tensor:
Tensor: The estimated complex-valued RTF of target speech.
Tensor of dimension `(..., freq, channel)`
"""
assert psd_s.is_complex(), f"The type of psd_s must be ``torch.cfloat`` or ``torch.cdouble``. Found {psd_s.dtype}."
assert (
psd_s.shape[-1] == psd_s.shape[-2]
), f"The last two dimensions of psd_s should be the same. Found {psd_s.shape}."
_, v = torch.linalg.eigh(psd_s) # v is sorted along with eigenvalues in ascending order
rtf = v[..., -1] # choose the eigenvector with max eigenvalue
return rtf
......@@ -1998,7 +2045,9 @@ def rtf_power(
torch.Tensor: The estimated complex-valued RTF of target speech.
Tensor of dimension `(..., freq, channel)`.
"""
_assert_psd_matrices(psd_s, psd_n)
assert n_iter > 0, "The number of iteration must be greater than 0."
# Apply diagonal loading to psd_n to improve robustness.
if diagonal_loading:
psd_n = _tik_reg(psd_n, reg=diag_eps)
......@@ -2048,6 +2097,16 @@ def apply_beamforming(beamform_weights: Tensor, specgram: Tensor) -> Tensor:
Tensor: The single-channel complex-valued enhanced spectrum.
Tensor of dimension `(..., freq, time)`
"""
assert (
beamform_weights.shape[:-2] == specgram.shape[:-3]
), "The dimensions except the last two dimensions of beamform_weights should be the same "
"as the dimensions except the last three dimensions of specgram."
f"Found {beamform_weights.shape} for beamform_weights and {specgram.shape} for specgram."
assert (
beamform_weights.is_complex() and specgram.is_complex()
), "The type of beamform_weights and specgram must be ``torch.cfloat`` or ``torch.cdouble``."
f"Found {beamform_weights.dtype} for beamform_weights and {specgram.dtype} for specgram."
# (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_weights.conj(), specgram])
return specgram_enhanced
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment