Unverified Commit ac97ad82 authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Move MVDR and PSD modules to transforms (#1771)

parent 88ca1e05
...@@ -95,4 +95,42 @@ ...@@ -95,4 +95,42 @@
pages={4779--4783}, pages={4779--4783},
year={2018}, year={2018},
organization={IEEE} organization={IEEE}
} }
\ No newline at end of file @inproceedings{souden2009optimal,
title={On optimal frequency-domain multichannel linear filtering for noise reduction},
author={Souden, Mehrez and Benesty, Jacob and Affes, Sofiene},
booktitle={IEEE Transactions on audio, speech, and language processing},
volume={18},
number={2},
pages={260--276},
year={2009},
publisher={IEEE}
}
@inproceedings{higuchi2016robust,
title={Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise},
author={Higuchi, Takuya and Ito, Nobutaka and Yoshioka, Takuya and Nakatani, Tomohiro},
booktitle={2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={5210--5214},
year={2016},
organization={IEEE}
}
@article{mises1929praktische,
title={Praktische Verfahren der Gleichungsaufl{\"o}sung.},
author={Mises, RV and Pollaczek-Geiringer, Hilda},
journal={ZAMM-Journal of Applied Mathematics and Mechanics/Zeitschrift f{\"u}r Angewandte Mathematik und Mechanik},
volume={9},
number={1},
pages={58--77},
year={1929},
publisher={Wiley Online Library}
}
@article{higuchi2017online,
title={Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR},
author={Higuchi, Takuya and Ito, Nobutaka and Araki, Shoko and Yoshioka, Takuya and Delcroix, Marc and Nakatani, Tomohiro},
journal={IEEE/ACM Transactions on Audio, Speech, and Language Processing},
volume={25},
number={4},
pages={780--793},
year={2017},
publisher={IEEE}
}
...@@ -188,6 +188,23 @@ Transforms are common audio transforms. They can be chained together using :clas ...@@ -188,6 +188,23 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward .. automethod:: forward
:hidden:`Multi-channel`
~~~~~~~~~~~~~~~~~~~~~~~
:hidden:`PSD`
-------------
.. autoclass:: PSD
.. automethod:: forward
:hidden:`MVDR`
--------------
.. autoclass:: MVDR
.. automethod:: forward
References References
~~~~~~~~~~ ~~~~~~~~~~
......
"""Implementation of MVDR Beamforming Module
Based on https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/beamformer.py
We provide three solutions of MVDR beamforming. One is based on reference channel selection:
Souden, Mehrez, Jacob Benesty, and Sofiene Affes.
"On optimal frequency-domain multichannel linear filtering for noise reduction."
IEEE Transactions on audio, speech, and language processing 18.2 (2009): 260-276.
The other two solutions are based on the steering vector. We apply either eigenvalue decomposition
or the power method to get the steering vector from the PSD matrices.
For eigenvalue decomposistion method, please refer:
Higuchi, Takuya, et al. "Robust MVDR beamforming using time-frequency masks for online/offline ASR in noise."
2016 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2016.
For power method, please refer:
Mises, R. V., and Hilda Pollaczek‐Geiringer.
"Praktische Verfahren der Gleichungsauflösung."
ZAMM‐Journal of Applied Mathematics and Mechanics/Zeitschrift für Angewandte Mathematik und Mechanik 9.1 (1929): 58-77.
For online streaming audio, we provide a recursive method to update PSD matrices based on:
Higuchi, Takuya, et al.
"Online MVDR beamformer based on complex Gaussian mixture model with spatial prior for noise robust ASR."
IEEE/ACM Transactions on Audio, Speech, and Language Processing 25.4 (2017): 780-793.
"""
from typing import Optional
import torch
def mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
    r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.

    Args:
        input (torch.Tensor): Tensor of dimension (..., channel, channel)
        dim1 (int, optional): the first dimension of the diagonal matrix
            (Default: -1)
        dim2 (int, optional): the second dimension of the diagonal matrix
            (Default: -2)

    Returns:
        torch.Tensor: trace of the input Tensor
    """
    assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
    assert input.shape[dim1] == input.shape[dim2],\
        "The size of ``dim1`` and ``dim2`` must be the same."
    # Extract the diagonal along the two matrix dimensions, then reduce it.
    diagonal = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
    return diagonal.sum(-1)
class PSD(torch.nn.Module):
    r"""Compute cross-channel power spectral density (PSD) matrix.

    Args:
        multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
        normalize (bool, optional): whether normalize the mask along the time dimension
        eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-15
    """

    def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
        super().__init__()
        self.multi_mask = multi_mask
        self.normalize = normalize
        self.eps = eps

    def forward(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Args:
            X (torch.Tensor): multi-channel complex-valued STFT matrix.
                Tensor of dimension (..., channel, freq, time)
            mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
                Tensor of dimension (..., freq, time) if multi_mask is ``False`` or
                of dimension (..., channel, freq, time) if multi_mask is ``True``

        Returns:
            torch.Tensor: PSD matrix of the input STFT matrix.
                Tensor of dimension (..., freq, channel, channel)
        """
        # Cross-channel outer product:
        # (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., freq, time, ch_1, ch_2)
        outer = torch.einsum("...cft,...eft->...ftce", [X, X.conj()])
        if mask is not None:
            if self.multi_mask:
                # Collapse the channel dimension of the mask by averaging.
                mask = mask.mean(dim=-3)  # (..., freq, time)
            if self.normalize:
                # Normalize the mask so its weights sum to one over time.
                mask = mask / (mask.sum(dim=-1, keepdim=True) + self.eps)
            outer = outer * mask.unsqueeze(-1).unsqueeze(-1)
        # Reduce over the time dimension to obtain the PSD matrix.
        return outer.sum(dim=-3)
class MVDR(torch.nn.Module):
    """MVDR module that performs MVDR beamforming with Time-Frequency masks.

    Args:
        ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
        solution (str, optional): the solution to get MVDR weight.
            Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
        multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks (Default: ``False``)
        diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise
            (Default: ``True``)
        diag_eps (float, optional): the coefficient multiplied to the identity matrix for diagonal loading
            (Default: 1e-7)
        online (bool, optional): whether to update the mvdr vector based on the previous psd matrices.
            (Default: ``False``)

    Note:
        If you use ``stv_evd`` solution, the gradient of the same input may not be identical if the
        eigenvalues of the PSD matrix are not distinct (i.e. some eigenvalues are close or identical).
    """

    def __init__(
        self,
        ref_channel: int = 0,
        solution: str = "ref_channel",
        multi_mask: bool = False,
        diag_loading: bool = True,
        diag_eps: float = 1e-7,
        online: bool = False,
    ):
        super().__init__()
        assert solution in ["ref_channel", "stv_evd", "stv_power"],\
            "Unknown solution provided. Must be one of [``ref_channel``, ``stv_evd``, ``stv_power``]."
        self.ref_channel = ref_channel
        self.solution = solution
        self.multi_mask = multi_mask
        self.diag_loading = diag_loading
        self.diag_eps = diag_eps
        self.online = online
        self.psd = PSD(multi_mask)

        # Running state for online (streaming) mode. The buffers start as 1-D
        # placeholders; ``_get_updated_mvdr_vector`` checks ``ndim == 1`` to
        # detect the first chunk and initialize them with real estimates.
        psd_s: torch.Tensor = torch.zeros(1)
        psd_n: torch.Tensor = torch.zeros(1)
        mask_sum_s: torch.Tensor = torch.zeros(1)
        mask_sum_n: torch.Tensor = torch.zeros(1)
        self.register_buffer('psd_s', psd_s)
        self.register_buffer('psd_n', psd_n)
        self.register_buffer('mask_sum_s', mask_sum_s)
        self.register_buffer('mask_sum_n', mask_sum_n)

    def _get_updated_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        mask_s: torch.Tensor,
        mask_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Recursively update the MVDR beamforming vector.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            mask_s (torch.Tensor): T-F mask of target speech
            mask_n (torch.Tensor): T-F mask of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8)

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if self.multi_mask:
            # Averaging mask along channel dimension
            mask_s = mask_s.mean(dim=-3)  # (..., freq, time)
            mask_n = mask_n.mean(dim=-3)  # (..., freq, time)
        if self.psd_s.ndim == 1:
            # First chunk: no history yet, so the current chunk's estimates
            # become the initial state.
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = mask_s.sum(dim=-1)
            self.mask_sum_n = mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
        else:
            # Blend the accumulated PSDs with the current chunk's, then update
            # the running state before computing the new weight.
            psd_s = self._get_updated_psd_speech(psd_s, mask_s)
            psd_n = self._get_updated_psd_noise(psd_n, mask_n)
            self.psd_s = psd_s
            self.psd_n = psd_n
            self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
            self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
            return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)

    def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
        r"""Update psd of speech recursively.

        Blends the accumulated PSD (weighted by the accumulated mask sum) with
        the current chunk's PSD (weighted by 1 over the new total mask sum).
        NOTE(review): the current-chunk weight is ``1/total`` rather than
        ``current_sum/total`` — this presumes the incoming ``psd_s`` carries the
        chunk's mask weight already; verify against ``PSD``'s normalization.

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            mask_s (torch.Tensor): T-F mask of target speech

        Returns:
            torch.Tensor: the updated psd of speech
        """
        numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
        denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
        psd_s = self.psd_s * numerator[..., None, None] + psd_s * denominator[..., None, None]
        return psd_s

    def _get_updated_psd_noise(self, psd_n: torch.Tensor, mask_n: torch.Tensor) -> torch.Tensor:
        r"""Update psd of noise recursively.

        Same recursion as :meth:`_get_updated_psd_speech`, applied to the noise
        PSD and noise mask state.

        Args:
            psd_n (torch.Tensor): psd matrix of target noise
            mask_n (torch.Tensor): T-F mask of target noise

        Returns:
            torch.Tensor: the updated psd of noise
        """
        numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
        denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
        psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
        return psd_n

    def _get_mvdr_vector(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        reference_vector: torch.Tensor,
        solution: str = 'ref_channel',
        diagonal_loading: bool = True,
        diag_eps: float = 1e-7,
        eps: float = 1e-8,
    ) -> torch.Tensor:
        r"""Compute the MVDR beamforming vector with the selected solution
        (reference-channel selection or one of the steering-vector methods).

        Args:
            psd_s (torch.Tensor): psd matrix of target speech
            psd_n (torch.Tensor): psd matrix of noise
            reference_vector (torch.Tensor): one-hot reference channel matrix
            solution (str, optional): the solution to estimate the beamforming weight
                (Default: ``ref_channel``)
            diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
                (Default: ``True``)
            diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading
                (Default: 1e-7)
            eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-8

        Returns:
            torch.Tensor: the mvdr beamforming weight matrix
        """
        if diagonal_loading:
            # Regularize the noise PSD so the linear solves below are stable.
            psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps)
        if solution == "ref_channel":
            numerator = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
            # ws: (..., C, C) / (...,) -> (..., C, C)
            ws = numerator / (mat_trace(numerator)[..., None, None] + eps)
            # h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
            beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
        else:
            if solution == "stv_evd":
                stv = self._get_steering_vector_evd(psd_s)
            else:
                stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector)
            # numerator = psd_n.inv() @ stv
            numerator = torch.linalg.solve(psd_n, stv).squeeze(-1)  # (..., freq, channel)
            # denominator = stv^H @ psd_n.inv() @ stv
            denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
            # normalize the numerator
            scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
            beamform_vector = numerator * scale / (denominator.real.unsqueeze(-1) + eps)
        return beamform_vector

    def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
        r"""Estimate the steering vector by eigenvalue decomposition.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)

        Returns:
            torch.Tensor: the estimated steering vector
                Tensor of dimension (..., freq, channel, 1)
        """
        w, v = torch.linalg.eig(psd_s)  # (..., freq, channel, channel)
        # Pick the eigenvector whose eigenvalue has the largest magnitude.
        _, indices = torch.max(w.abs(), dim=-1, keepdim=True)
        indices = indices.unsqueeze(-1)
        stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,)))  # (..., freq, channel, 1)
        return stv

    def _get_steering_vector_power(
        self,
        psd_s: torch.Tensor,
        psd_n: torch.Tensor,
        reference_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Estimate the steering vector by the power method.

        Args:
            psd_s (torch.tensor): covariance matrix of speech
                Tensor of dimension (..., freq, channel, channel)
            psd_n (torch.Tensor): covariance matrix of noise
                Tensor of dimension (..., freq, channel, channel)
            reference_vector (torch.Tensor): one-hot reference channel matrix

        Returns:
            torch.Tensor: the estimated steering vector
                Tensor of dimension (..., freq, channel, 1)
        """
        phi = torch.linalg.solve(psd_n, psd_s)  # psd_n.inv() @ psd_s
        # Start from the reference-channel column of phi.
        stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
        stv = stv.unsqueeze(-1)
        # NOTE(review): one extra application of ``phi`` followed by ``psd_s``
        # — presumably a fixed number of power iterations as in espnet's
        # beamformer; confirm against the reference implementation.
        stv = torch.matmul(phi, stv)
        stv = torch.matmul(psd_s, stv)
        return stv

    def _apply_beamforming_vector(
        self,
        X: torch.Tensor,
        beamform_vector: torch.Tensor
    ) -> torch.Tensor:
        r"""Apply the beamforming weight to the noisy STFT

        Args:
            X (torch.tensor): multi-channel noisy STFT
                Tensor of dimension (..., channel, freq, time)
            beamform_vector (torch.Tensor): beamforming weight matrix
                Tensor of dimension (..., freq, channel)

        Returns:
            torch.Tensor: the enhanced STFT
                Tensor of dimension (..., freq, time)
        """
        # (..., freq, channel) x (..., channel, freq, time) -> (..., freq, time)
        Y = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), X])
        return Y

    def _tik_reg(
        self,
        mat: torch.Tensor,
        reg: float = 1e-7,
        eps: float = 1e-8
    ) -> torch.Tensor:
        """Perform Tikhonov regularization (only modifying real part).

        Args:
            mat (torch.Tensor): input matrix (..., channel, channel)
            reg (float, optional): regularization factor (Default: 1e-7)
            eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)

        Returns:
            torch.Tensor: regularized matrix (..., channel, channel)
        """
        # Add eps
        C = mat.size(-1)
        eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
        with torch.no_grad():
            # Scale the loading by the matrix trace; detached so the
            # regularization strength does not contribute gradients.
            epsilon = mat_trace(mat).real[..., None, None] * reg
            # in case that correlation_matrix is all-zero
            epsilon = epsilon + eps
        mat = mat + epsilon * eye[..., :, :]
        return mat

    def forward(self, X: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Perform MVDR beamforming.

        Args:
            X (torch.Tensor): the multi-channel STFT of the noisy speech.
                Tensor of dimension (..., channel, freq, time)
            mask_s (torch.Tensor): Time-Frequency mask of target speech
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or of dimension (..., channel, freq, time) if multi_mask is ``True``
            mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise
                Tensor of dimension (..., freq, time) if multi_mask is ``False``
                or of dimension (..., channel, freq, time) if multi_mask is ``True``
                (Default: None)

        Returns:
            torch.Tensor: The single-channel STFT of the enhanced speech.
                Tensor of dimension (..., freq, time)
        """
        if X.ndim < 3:
            raise ValueError(
                f"Expected at least 3D tensor (..., channel, freq, time). Found: {X.shape}"
            )
        if X.dtype != torch.cdouble:
            raise ValueError(
                f"The type of the input STFT tensor must be ``torch.cdouble``. Found: {X.dtype}"
            )

        if mask_n is None:
            # Default the noise mask to the complement of the speech mask.
            mask_n = 1 - mask_s

        shape = X.size()

        # pack batch
        X = X.reshape(-1, shape[-3], shape[-2], shape[-1])
        if self.multi_mask:
            mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1])
        else:
            mask_s = mask_s.reshape(-1, shape[-2], shape[-1])
            mask_n = mask_n.reshape(-1, shape[-2], shape[-1])

        psd_s = self.psd(X, mask_s)  # (..., freq, channel, channel)
        psd_n = self.psd(X, mask_n)  # (..., freq, channel, channel)

        # One-hot vector selecting the reference channel.
        u = torch.zeros(
            X.size()[:-2],
            device=X.device,
            dtype=torch.cdouble
        )  # (..., channel)
        u[..., self.ref_channel].fill_(1)

        if self.online:
            w_mvdr = self._get_updated_mvdr_vector(
                psd_s,
                psd_n,
                mask_s,
                mask_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )
        else:
            w_mvdr = self._get_mvdr_vector(
                psd_s,
                psd_n,
                u,
                self.solution,
                self.diag_loading,
                self.diag_eps
            )

        Y = self._apply_beamforming_vector(X, w_mvdr)

        # unpack batch
        Y = Y.reshape(shape[:-3] + shape[-2:])
        return Y
...@@ -2,14 +2,6 @@ from typing import Optional ...@@ -2,14 +2,6 @@ from typing import Optional
import numpy as np import numpy as np
import torch import torch
from beamforming.mvdr import PSD
from parameterized import parameterized, param
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
)
def psd_numpy( def psd_numpy(
...@@ -33,28 +25,3 @@ def psd_numpy( ...@@ -33,28 +25,3 @@ def psd_numpy(
psd = psd.sum(axis=-3) psd = psd.sum(axis=-3)
return torch.tensor(psd, dtype=torch.cdouble) return torch.tensor(psd, dtype=torch.cdouble)
class TransformsTestBase(TestBaseMixin):
    @parameterized.expand([
        param(0.5, 1, True, False),
        param(0.5, 1, None, False),
        param(1, 4, True, True),
        param(1, 6, None, True),
    ])
    def test_psd(self, duration, channel, mask, multi_mask):
        """PSD output should match the NumPy reference implementation."""
        transform = PSD(multi_mask)
        waveform = get_whitenoise(sample_rate=8000, duration=duration, n_channels=channel)
        spectrogram = get_spectrogram(waveform, n_fft=400)  # (channel, freq, time)
        spectrogram = spectrogram.to(torch.cdouble)
        if mask is not None:
            # ``mask`` arrives as a truthy placeholder from parameterized;
            # replace it with a random T-F mask of the appropriate shape.
            if multi_mask:
                mask = torch.rand(spectrogram.shape[-3:])
            else:
                mask = torch.rand(spectrogram.shape[-2:])
            psd_np = psd_numpy(spectrogram.detach().numpy(), mask.detach().numpy(), multi_mask)
        else:
            psd_np = psd_numpy(spectrogram.detach().numpy(), mask, multi_mask)
        psd = transform(spectrogram, mask)
        self.assertEqual(psd, psd_np, atol=1e-5, rtol=1e-5)
from torchaudio_unittest.common_utils import PytorchTestCase
from .autograd_test_impl import AutogradTestMixin
# Runs the shared autograd test suite (gradcheck/gradgradcheck) on CPU.
class AutogradCPUTest(AutogradTestMixin, PytorchTestCase):
    device = 'cpu'
# NOTE(review): no test mixin is inherited here, so this class defines no
# tests in the code visible — presumably a placeholder for RNNT autograd
# tests; verify intent.
class AutogradRNNTCPUTest(PytorchTestCase):
    device = 'cpu'
from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoCuda,
)
from .autograd_test_impl import AutogradTestMixin
# Runs the shared autograd test suite on CUDA; skipped when CUDA is unavailable.
@skipIfNoCuda
class AutogradCUDATest(AutogradTestMixin, PytorchTestCase):
    device = 'cuda'
# NOTE(review): no test mixin is inherited here, so this class defines no
# tests in the code visible — presumably a placeholder for RNNT autograd
# tests on CUDA; verify intent.
@skipIfNoCuda
class AutogradRNNTCUDATest(PytorchTestCase):
    device = 'cuda'
from typing import List
import torch
from beamforming.mvdr import PSD, MVDR
from parameterized import parameterized, param
from torch.autograd import gradcheck, gradgradcheck
from torchaudio_unittest.common_utils import (
TestBaseMixin,
get_whitenoise,
get_spectrogram,
)
class AutogradTestMixin(TestBaseMixin):
    def assert_grad(
        self,
        transform: torch.nn.Module,
        inputs: List[torch.Tensor],
        *,
        nondet_tol: float = 0.0,
    ):
        """Assert ``transform`` passes gradcheck and gradgradcheck on ``inputs``.

        Non-tensor entries in ``inputs`` are forwarded unchanged; tensors are
        upcast to double precision and marked as requiring grad.
        """
        transform = transform.to(dtype=torch.float64, device=self.device)

        # gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
        # `torch.cdouble`, when the default eps and tolerance values are used.
        inputs_ = []
        for i in inputs:
            if torch.is_tensor(i):
                i = i.to(
                    dtype=torch.cdouble if i.is_complex() else torch.double,
                    device=self.device)
                i.requires_grad = True
            inputs_.append(i)
        assert gradcheck(transform, inputs_)
        assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)

    def test_psd(self):
        """PSD should be differentiable w.r.t. the input spectrogram."""
        transform = PSD()
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        self.assert_grad(transform, [spectrogram])

    @parameterized.expand([
        [True],
        [False],
    ])
    def test_psd_with_mask(self, multi_mask):
        """PSD should be differentiable w.r.t. spectrogram and T-F mask."""
        transform = PSD(multi_mask=multi_mask)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        if multi_mask:
            mask = torch.rand(spectrogram.shape[-3:])
        else:
            mask = torch.rand(spectrogram.shape[-2:])
        self.assert_grad(transform, [spectrogram, mask])

    @parameterized.expand([
        param(solution="ref_channel"),
        param(solution="stv_power"),
        # evd will fail since the eigenvalues are not distinct
        # param(solution="stv_evd"),
    ])
    def test_mvdr(self, solution):
        """MVDR should be differentiable w.r.t. spectrogram and mask."""
        transform = MVDR(solution=solution)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        mask = torch.rand(spectrogram.shape[-2:])
        self.assert_grad(transform, [spectrogram, mask])
"""Test numerical consistency among single input and batched input."""
import torch
from beamforming.mvdr import PSD, MVDR
from parameterized import parameterized
from torchaudio_unittest import common_utils
class TestTransforms(common_utils.TorchaudioTestCase):
    """Numerical consistency between single-item and batched inputs."""

    def test_batch_PSD(self):
        """Batched PSD should equal stacking per-item PSD outputs."""
        spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)

        # Single then transform then batch
        expected = []
        for i in range(4):
            expected.append(PSD()(spec[i]))
        expected = torch.stack(expected)

        # Batch then transform
        computed = PSD()(spec)

        self.assertEqual(computed, expected)

    def test_batch_PSD_with_mask(self):
        """Batched PSD with a mask should equal stacking per-item outputs."""
        spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)
        mask = torch.rand((4, 201, 100))

        # Single then transform then batch
        expected = []
        for i in range(4):
            expected.append(PSD()(spec[i], mask[i]))
        expected = torch.stack(expected)

        # Batch then transform
        computed = PSD()(spec, mask)

        self.assertEqual(computed, expected)

    @parameterized.expand([
        [True],
        [False],
    ])
    def test_MVDR(self, multi_mask):
        """Batched MVDR should equal stacking per-item outputs.

        A fresh MVDR instance is built per item so no online state is shared.
        """
        spec = torch.rand((4, 6, 201, 100), dtype=torch.cdouble)
        if multi_mask:
            mask = torch.rand((4, 6, 201, 100))
        else:
            mask = torch.rand((4, 201, 100))

        # Single then transform then batch
        expected = []
        for i in range(4):
            expected.append(MVDR(multi_mask=multi_mask)(spec[i], mask[i]))
        expected = torch.stack(expected)

        # Batch then transform
        computed = MVDR(multi_mask=multi_mask)(spec, mask)

        self.assertEqual(computed, expected)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat64Only
# TorchScript consistency tests at float32 on CPU.
class TestTransformsFloat32(Transforms, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cpu')
# TorchScript consistency tests at float64 on CPU, including the
# double-precision-only cases (e.g. MVDR).
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cpu')
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat64Only
# TorchScript consistency tests at float32 on CUDA; skipped without CUDA.
@skipIfNoCuda
class TestTransformsFloat32(Transforms, PytorchTestCase):
    dtype = torch.float32
    device = torch.device('cuda')
# TorchScript consistency tests at float64 on CUDA, including the
# double-precision-only cases; skipped without CUDA.
@skipIfNoCuda
class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
    dtype = torch.float64
    device = torch.device('cuda')
"""Test suites for jit-ability and its numerical compatibility"""
import torch
from beamforming.mvdr import PSD, MVDR
from parameterized import parameterized, param
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
TestBaseMixin,
)
class Transforms(TempDirMixin, TestBaseMixin):
    """Implements test for Transforms that are performed for different devices"""

    def _assert_consistency_complex(self, transform, tensors):
        """Script ``transform``, round-trip it through disk, and check the
        scripted output matches the eager output on ``tensors``."""
        assert tensors[0].is_complex()
        tensors = [tensor.to(device=self.device, dtype=self.complex_dtype) for tensor in tensors]
        transform = transform.to(device=self.device, dtype=self.dtype)

        path = self.get_temp_path('func.zip')
        torch.jit.script(transform).save(path)
        ts_transform = torch.jit.load(path)

        output = transform(*tensors)
        ts_output = ts_transform(*tensors)
        self.assertEqual(ts_output, output)

    def test_PSD(self):
        """PSD should be TorchScript-able and numerically consistent."""
        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
        spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
        self._assert_consistency_complex(PSD(), (spectrogram,))

    def test_PSD_with_mask(self):
        """PSD with a T-F mask should be TorchScript-able and consistent."""
        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
        spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
        mask = torch.rand(spectrogram.shape[-2:])
        self._assert_consistency_complex(PSD(), (spectrogram, mask))
class TransformsFloat64Only(TestBaseMixin):
    """TorchScript tests that require double precision (MVDR requires
    ``torch.cdouble`` STFT input)."""

    @parameterized.expand([
        param(solution="ref_channel", online=True),
        param(solution="stv_evd", online=True),
        param(solution="stv_power", online=True),
        param(solution="ref_channel", online=False),
        param(solution="stv_evd", online=False),
        param(solution="stv_power", online=False),
    ])
    def test_MVDR(self, solution, online):
        """MVDR should be TorchScript-able for every solution/online combo."""
        tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
        spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
        mask = torch.rand(spectrogram.shape[-2:])
        self._assert_consistency_complex(
            MVDR(solution=solution, online=online),
            (spectrogram, mask)
        )
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from . transforms_test_impl import TransformsTestBase
# Transform tests at float32 on CPU.
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
    device = 'cpu'
    dtype = torch.float32
# Transform tests at float64 on CPU.
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
    device = 'cpu'
    dtype = torch.float64
import torch
from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoCuda,
)
from . transforms_test_impl import TransformsTestBase
# Transform tests at float32 on CUDA; skipped without CUDA.
# NOTE(review): the class name says "CPU" while the device is 'cuda' —
# looks like a copy-paste from the CPU test file; consider renaming.
@skipIfNoCuda
class TransformsCPUFloat32Test(TransformsTestBase, PytorchTestCase):
    device = 'cuda'
    dtype = torch.float32
# Transform tests at float64 on CUDA; skipped without CUDA.
# Bug fix: this class is gated on CUDA availability (@skipIfNoCuda) but was
# pointing at device 'cpu' — a copy-paste slip from the CPU test file; the
# sibling float32 class above correctly targets 'cuda'.
# NOTE(review): the class name still says "CPU"; consider renaming in a
# follow-up (kept here to avoid breaking test selection by name).
@skipIfNoCuda
class TransformsCPUFloat64Test(TransformsTestBase, PytorchTestCase):
    device = 'cuda'
    dtype = torch.float64
...@@ -262,6 +262,42 @@ class AutogradTestMixin(TestBaseMixin): ...@@ -262,6 +262,42 @@ class AutogradTestMixin(TestBaseMixin):
spectrogram = torch.view_as_real(spectrogram) spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram]) self.assert_grad(transform, [spectrogram])
    def test_psd(self):
        """T.PSD should be differentiable w.r.t. the input spectrogram."""
        transform = T.PSD()
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        self.assert_grad(transform, [spectrogram])
    @parameterized.expand([
        [True],
        [False],
    ])
    def test_psd_with_mask(self, multi_mask):
        """T.PSD should be differentiable w.r.t. spectrogram and T-F mask."""
        transform = T.PSD(multi_mask=multi_mask)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        if multi_mask:
            mask = torch.rand(spectrogram.shape[-3:])
        else:
            mask = torch.rand(spectrogram.shape[-2:])
        self.assert_grad(transform, [spectrogram, mask])
    @parameterized.expand([
        "ref_channel",
        # stv_power test time too long, comment for now
        # "stv_power",
        # stv_evd will fail since the eigenvalues are not distinct
        # "stv_evd",
    ])
    def test_mvdr(self, solution):
        """T.MVDR should be differentiable w.r.t. spectrogram and both masks."""
        transform = T.MVDR(solution=solution)
        waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
        spectrogram = get_spectrogram(waveform, n_fft=400)
        mask_s = torch.rand(spectrogram.shape[-2:])
        mask_n = torch.rand(spectrogram.shape[-2:])
        self.assert_grad(transform, [spectrogram, mask_s, mask_n])
class AutogradTestFloat32(TestBaseMixin): class AutogradTestFloat32(TestBaseMixin):
def assert_grad( def assert_grad(
......
...@@ -175,3 +175,54 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -175,3 +175,54 @@ class TestTransforms(common_utils.TorchaudioTestCase):
transform = T.PitchShift(sample_rate, n_steps, n_fft=400) transform = T.PitchShift(sample_rate, n_steps, n_fft=400)
self.assert_batch_consistency(transform, waveform) self.assert_batch_consistency(transform, waveform)
    def test_batch_PSD(self):
        """T.PSD should produce consistent results on batched input."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
        # Fold the 6 channels into a batch of 3 two-channel spectrograms.
        specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
        transform = T.PSD()
        self.assert_batch_consistency(transform, specgram)
    def test_batch_PSD_with_mask(self):
        """Batched T.PSD with a mask should equal stacking per-item outputs."""
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.to(torch.double)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
        # Fold the 6 channels into a batch of 3 two-channel spectrograms.
        specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
        mask = torch.rand((3, specgram.shape[-2], specgram.shape[-1]))
        transform = T.PSD()

        # Single then transform then batch
        expected = [transform(specgram[i], mask[i]) for i in range(3)]
        expected = torch.stack(expected)

        # Batch then transform
        computed = transform(specgram, mask)

        self.assertEqual(computed, expected)
    @parameterized.expand([
        [True],
        [False],
    ])
    def test_MVDR(self, multi_mask):
        """Batched T.MVDR should equal stacking per-item outputs.

        Note: a single transform instance is reused across items; this is
        valid here because ``online=False`` (the default), so no state is
        carried between calls.
        """
        waveform = common_utils.get_whitenoise(sample_rate=8000, duration=1, n_channels=6)
        waveform = waveform.to(torch.double)
        specgram = common_utils.get_spectrogram(waveform, n_fft=400)
        # Fold the 6 channels into a batch of 3 two-channel spectrograms.
        specgram = specgram.reshape(3, 2, specgram.shape[-2], specgram.shape[-1])
        if multi_mask:
            mask_s = torch.rand((3, 2, specgram.shape[-2], specgram.shape[-1]))
            mask_n = torch.rand((3, 2, specgram.shape[-2], specgram.shape[-1]))
        else:
            mask_s = torch.rand((3, specgram.shape[-2], specgram.shape[-1]))
            mask_n = torch.rand((3, specgram.shape[-2], specgram.shape[-1]))
        transform = T.MVDR(multi_mask=multi_mask)

        # Single then transform then batch
        expected = [transform(specgram[i], mask_s[i], mask_n[i]) for i in range(3)]
        expected = torch.stack(expected)

        # Batch then transform
        computed = transform(specgram, mask_s, mask_n)

        self.assertEqual(computed, expected)
import torch import torch
from torchaudio_unittest.common_utils import PytorchTestCase from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only
class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
...@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): ...@@ -9,6 +9,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
device = torch.device('cpu') device = torch.device('cpu')
class TestTransformsFloat64(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device('cpu')
import torch import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsFloat32Only from .torchscript_consistency_impl import Transforms, TransformsFloat32Only, TransformsFloat64Only
@skipIfNoCuda @skipIfNoCuda
...@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase): ...@@ -11,6 +11,6 @@ class TestTransformsFloat32(Transforms, TransformsFloat32Only, PytorchTestCase):
@skipIfNoCuda @skipIfNoCuda
class TestTransformsFloat64(Transforms, PytorchTestCase): class TestTransformsFloat64(Transforms, TransformsFloat64Only, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device('cuda')
...@@ -24,7 +24,7 @@ class Transforms(TestBaseMixin): ...@@ -24,7 +24,7 @@ class Transforms(TestBaseMixin):
ts_output = ts_transform(tensor, *args) ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False): def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args):
assert tensor.is_complex() assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype) tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype) transform = transform.to(device=self.device, dtype=self.dtype)
...@@ -33,9 +33,8 @@ class Transforms(TestBaseMixin): ...@@ -33,9 +33,8 @@ class Transforms(TestBaseMixin):
if test_pseudo_complex: if test_pseudo_complex:
tensor = torch.view_as_real(tensor) tensor = torch.view_as_real(tensor)
output = transform(tensor, *args)
output = transform(tensor) ts_output = ts_transform(tensor, *args)
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output) self.assertEqual(ts_output, output)
def test_Spectrogram(self): def test_Spectrogram(self):
...@@ -152,6 +151,19 @@ class Transforms(TestBaseMixin): ...@@ -152,6 +151,19 @@ class Transforms(TestBaseMixin):
waveform waveform
) )
def test_PSD(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device)
self._assert_consistency_complex(T.PSD(), spectrogram)
def test_PSD_with_mask(self):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device)
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask)
class TransformsFloat32Only(TestBaseMixin): class TransformsFloat32Only(TestBaseMixin):
def test_rnnt_loss(self): def test_rnnt_loss(self):
...@@ -167,3 +179,24 @@ class TransformsFloat32Only(TestBaseMixin): ...@@ -167,3 +179,24 @@ class TransformsFloat32Only(TestBaseMixin):
target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32) target_lengths = torch.tensor([2], device=tensor.device, dtype=torch.int32)
self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths) self._assert_consistency(T.RNNTLoss(), logits, targets, logit_lengths, target_lengths)
class TransformsFloat64Only(TestBaseMixin):
@parameterized.expand([
["ref_channel", True],
["stv_evd", True],
["stv_power", True],
["ref_channel", False],
["stv_evd", False],
["stv_power", False],
])
def test_MVDR(self, solution, online):
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=4)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(device=self.device, dtype=torch.cdouble)
mask_s = torch.rand(spectrogram.shape[-2:], device=self.device)
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(
T.MVDR(solution=solution, online=online),
spectrogram, False, mask_s, mask_n
)
...@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import (
get_spectrogram, get_spectrogram,
nested_params, nested_params,
) )
from torchaudio_unittest.common_utils.psd_utils import psd_numpy
def _get_ratio(mat): def _get_ratio(mat):
...@@ -108,3 +109,26 @@ class TransformsTestBase(TestBaseMixin): ...@@ -108,3 +109,26 @@ class TransformsTestBase(TestBaseMixin):
transformed = s.forward(waveform) transformed = s.forward(waveform)
restored = inv_s.forward(transformed, length=waveform.shape[-1]) restored = inv_s.forward(transformed, length=waveform.shape[-1])
self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6) self.assertEqual(waveform, restored, atol=1e-6, rtol=1e-6)
@parameterized.expand([
param(0.5, 1, True, False),
param(0.5, 1, None, False),
param(1, 4, True, True),
param(1, 6, None, True),
])
def test_psd(self, duration, channel, mask, multi_mask):
"""Providing dtype changes the kernel cache dtype"""
transform = T.PSD(multi_mask)
waveform = get_whitenoise(sample_rate=8000, duration=duration, n_channels=channel)
spectrogram = get_spectrogram(waveform, n_fft=400) # (channel, freq, time)
spectrogram = spectrogram.to(torch.cdouble)
if mask is not None:
if multi_mask:
mask = torch.rand(spectrogram.shape[-3:])
else:
mask = torch.rand(spectrogram.shape[-2:])
psd_np = psd_numpy(spectrogram.detach().numpy(), mask.detach().numpy(), multi_mask)
else:
psd_np = psd_numpy(spectrogram.detach().numpy(), mask, multi_mask)
psd = transform(spectrogram, mask)
self.assertEqual(psd, psd_np, atol=1e-5, rtol=1e-5)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment