"...git@developer.sourcefind.cn:OpenDAS/lmdeploy.git" did not exist on "2a4754785144a08f1e1feeb11fad87bbd6e41610"
Commit f5036c71 authored by Zhaoheng Ni, committed by Facebook GitHub Bot
Browse files

Refactor MVDR module (#2383)

Summary:
- Use `apply_beamforming`, `rtf_evd`, `rtf_power`, `mvdr_weights_souden`, `mvdr_weights_rtf` methods under `torchaudio.functional` to replace the class methods.
- Refactor docstrings in `PSD` and `MVDR`.
- Move `_get_mvdr_vector` out of the `MVDR` class, since it does not call any instance methods.
- Since MVDR uses einsum for matrix operations, packing and unpacking the batch dimensions is not necessary, so that code is removed. Batch handling is verified by the [batch_consistency_test](https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/transforms/batch_consistency_test.py#L202).

Pull Request resolved: https://github.com/pytorch/audio/pull/2383

Reviewed By: carolineechen, mthrok

Differential Revision: D36338373

Pulled By: nateanl

fbshipit-source-id: a48a6ae2825657e5967a19656245596cdf037c5f
parent 09639680
...@@ -1748,17 +1748,18 @@ def psd( ...@@ -1748,17 +1748,18 @@ def psd(
.. properties:: Autograd TorchScript .. properties:: Autograd TorchScript
Args: Args:
specgram (Tensor): Multi-channel complex-valued spectrum. specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)` Tensor with dimensions `(..., channel, freq, time)`.
mask (Tensor or None, optional): Real-valued time-frequency mask mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
for normalization. Tensor of dimension `(..., freq, time)` Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``) (Default: ``None``)
normalize (bool, optional): whether to normalize the mask along the time dimension. (Default: ``True``) normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``) eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
Returns: Returns:
Tensor: The complex-valued PSD matrix of the input spectrum. torch.Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)` Tensor with dimensions `(..., freq, channel, channel)`
""" """
specgram = specgram.transpose(-3, -2) # shape (freq, channel, time) specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
# outer product: # outer product:
...@@ -1780,14 +1781,14 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t ...@@ -1780,14 +1781,14 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions. r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args: Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)` input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
dim1 (int, optional): the first dimension of the diagonal matrix dim1 (int, optional): The first dimension of the diagonal matrix.
(Default: -1) (Default: ``-1``)
dim2 (int, optional): the second dimension of the diagonal matrix dim2 (int, optional): The second dimension of the diagonal matrix.
(Default: -2) (Default: ``-2``)
Returns: Returns:
Tensor: trace of the input Tensor Tensor: The trace of the input Tensor.
""" """
assert input.ndim >= 2, "The dimension of the tensor must be at least 2." assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same." assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
...@@ -1799,12 +1800,12 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T ...@@ -1799,12 +1800,12 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T
"""Perform Tikhonov regularization (only modifying real part). """Perform Tikhonov regularization (only modifying real part).
Args: Args:
mat (torch.Tensor): input matrix (..., channel, channel) mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
reg (float, optional): regularization factor (Default: 1e-8) reg (float, optional): Regularization factor. (Default: 1e-7)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``) eps (float, optional): Value to avoid the correlation matrix is all-zero. (Default: ``1e-8``)
Returns: Returns:
Tensor: regularized matrix (..., channel, channel) Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
""" """
# Add eps # Add eps
C = mat.size(-1) C = mat.size(-1)
......
...@@ -11,23 +11,45 @@ from torchaudio import functional as F ...@@ -11,23 +11,45 @@ from torchaudio import functional as F
__all__ = [] __all__ = []
def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor: def _get_mvdr_vector(
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions. psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> torch.Tensor:
r"""Compute the MVDR beamforming weights with ``solution`` argument.
Args: Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)` psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
dim1 (int, optional): the first dimension of the diagonal matrix Tensor with dimensions `(..., freq, channel, channel)`.
(Default: -1) psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
dim2 (int, optional): the second dimension of the diagonal matrix Tensor with dimensions `(..., freq, channel, channel)`.
(Default: -2) reference_vector (torch.Tensor): one-hot reference channel matrix.
solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns: Returns:
torch.Tensor: trace of the input Tensor torch.Tensor: the mvdr beamforming weight matrix
""" """
assert input.ndim >= 2, "The dimension of the tensor must be at least 2." if solution == "ref_channel":
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same." beamform_vector = F.mvdr_weights_souden(psd_s, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2) else:
return input.sum(dim=-1) if solution == "stv_evd":
stv = F.rtf_evd(psd_s)
else:
stv = F.rtf_power(psd_s, psd_n, reference_vector, diagonal_loading=diagonal_loading, diag_eps=diag_eps)
beamform_vector = F.mvdr_weights_rtf(stv, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
return beamform_vector
class PSD(torch.nn.Module): class PSD(torch.nn.Module):
...@@ -38,9 +60,9 @@ class PSD(torch.nn.Module): ...@@ -38,9 +60,9 @@ class PSD(torch.nn.Module):
.. properties:: Autograd TorchScript .. properties:: Autograd TorchScript
Args: Args:
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``) multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
normalize (bool, optional): whether normalize the mask along the time dimension. normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-15) eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
""" """
def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15): def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
...@@ -52,32 +74,23 @@ class PSD(torch.nn.Module): ...@@ -52,32 +74,23 @@ class PSD(torch.nn.Module):
def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None): def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
""" """
Args: Args:
specgram (torch.Tensor): multi-channel complex-valued STFT matrix. specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)` Tensor with dimensions `(..., channel, freq, time)`.
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization. mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
of dimension `(..., channel, freq, time)` if multi_mask is ``True`` with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``)
Returns: Returns:
Tensor: PSD matrix of the input STFT matrix. torch.Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)` Tensor with dimensions `(..., freq, channel, channel)`
""" """
# outer product:
# (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
psd = torch.einsum("...cft,...eft->...ftce", [specgram, specgram.conj()])
if mask is not None: if mask is not None:
if self.multi_mask: if self.multi_mask:
# Averaging mask along channel dimension # Averaging mask along channel dimension
mask = mask.mean(dim=-3) # (..., freq, time) mask = mask.mean(dim=-3) # (..., freq, time)
psd = F.psd(specgram, mask, self.normalize, self.eps)
# Normalized mask along time dimension:
if self.normalize:
mask = mask / (mask.sum(dim=-1, keepdim=True) + self.eps)
psd = psd * mask.unsqueeze(-1).unsqueeze(-1)
psd = psd.sum(dim=-3)
return psd return psd
...@@ -128,16 +141,16 @@ class MVDR(torch.nn.Module): ...@@ -128,16 +141,16 @@ class MVDR(torch.nn.Module):
PSD matrices of speech and noise, respectively. PSD matrices of speech and noise, respectively.
Args: Args:
ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``) ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
solution (str, optional): the solution to get MVDR weight. solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``) Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``) multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise. diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
(Default: ``True``) of the noise. (Default: ``True``)
diag_eps (float, optional): the coefficient multipied to the identity matrix for diagonal loading. diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
(Default: 1e-7) It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
online (bool, optional): whether to update the mvdr vector based on the previous psd matrices. online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
(Default: ``False``) the previous covariance matrices. (Default: ``False``)
Note: Note:
To improve the numerical stability, the input spectrogram will be converted to double precision To improve the numerical stability, the input spectrogram will be converted to double precision
...@@ -196,21 +209,28 @@ class MVDR(torch.nn.Module): ...@@ -196,21 +209,28 @@ class MVDR(torch.nn.Module):
r"""Recursively update the MVDR beamforming vector. r"""Recursively update the MVDR beamforming vector.
Args: Args:
psd_s (torch.Tensor): psd matrix of target speech psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
psd_n (torch.Tensor): psd matrix of noise Tensor with dimensions `(..., freq, channel, channel)`.
mask_s (torch.Tensor): T-F mask of target speech psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
mask_n (torch.Tensor): T-F mask of noise Tensor with dimensions `(..., freq, channel, channel)`.
reference_vector (torch.Tensor): one-hot reference channel matrix mask_s (torch.Tensor): Time-Frequency mask of the target speech.
solution (str, optional): the solution to estimate the beamforming weight Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
(Default: ``ref_channel``) or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
reference_vector (torch.Tensor): One-hot reference channel matrix.
solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``) (Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
(Default: 1e-7) It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8) eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns: Returns:
Tensor: the mvdr beamforming weight matrix torch.Tensor: The MVDR beamforming weight matrix.
""" """
if self.multi_mask: if self.multi_mask:
# Averaging mask along channel dimension # Averaging mask along channel dimension
...@@ -221,7 +241,7 @@ class MVDR(torch.nn.Module): ...@@ -221,7 +241,7 @@ class MVDR(torch.nn.Module):
self.psd_n = psd_n self.psd_n = psd_n
self.mask_sum_s = mask_s.sum(dim=-1) self.mask_sum_s = mask_s.sum(dim=-1)
self.mask_sum_n = mask_n.sum(dim=-1) self.mask_sum_n = mask_n.sum(dim=-1)
return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps) return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
else: else:
psd_s = self._get_updated_psd_speech(psd_s, mask_s) psd_s = self._get_updated_psd_speech(psd_s, mask_s)
psd_n = self._get_updated_psd_noise(psd_n, mask_n) psd_n = self._get_updated_psd_noise(psd_n, mask_n)
...@@ -229,17 +249,19 @@ class MVDR(torch.nn.Module): ...@@ -229,17 +249,19 @@ class MVDR(torch.nn.Module):
self.psd_n = psd_n self.psd_n = psd_n
self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1) self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1) self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps) return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor: def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
r"""Update psd of speech recursively. r"""Update psd of speech recursively.
Args: Args:
psd_s (torch.Tensor): psd matrix of target speech psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
mask_s (torch.Tensor): T-F mask of target speech Tensor with dimensions `(..., freq, channel, channel)`.
mask_s (torch.Tensor): Time-Frequency mask of the target speech.
Tensor with dimensions `(..., freq, time)`.
Returns: Returns:
torch.Tensor: the updated psd of speech torch.Tensor: The updated PSD matrix of target speech.
""" """
numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1)) numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1)) denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
...@@ -250,162 +272,37 @@ class MVDR(torch.nn.Module): ...@@ -250,162 +272,37 @@ class MVDR(torch.nn.Module):
r"""Update psd of noise recursively. r"""Update psd of noise recursively.
Args: Args:
psd_n (torch.Tensor): psd matrix of target noise psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
mask_n (torch.Tensor): T-F mask of target noise Tensor with dimensions `(..., freq, channel, channel)`.
mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
Tensor with dimensions `(..., freq, time)`.
Returns: Returns:
torch.Tensor: the updated psd of noise torch.Tensor: The updated PSD matrix of noise.
""" """
numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1)) numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1)) denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None] psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
return psd_n return psd_n
def _get_mvdr_vector(
self,
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> torch.Tensor:
r"""Compute beamforming vector by the reference channel selection method.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
solution (str, optional): the solution to estimate the beamforming weight
(Default: ``ref_channel``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-8
Returns:
torch.Tensor: the mvdr beamforming weight matrix
"""
if diagonal_loading:
psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps)
if solution == "ref_channel":
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_get_mat_trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
else:
if solution == "stv_evd":
stv = self._get_steering_vector_evd(psd_s)
else:
stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector)
# numerator = psd_n.inv() @ stv
numerator = torch.linalg.solve(psd_n, stv).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
# normalzie the numerator
scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
beamform_vector = numerator / (denominator.real.unsqueeze(-1) + eps) * scale
return beamform_vector
def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
r"""Estimate the steering vector by eigenvalue decomposition.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
w, v = torch.linalg.eig(psd_s) # (..., freq, channel, channel)
_, indices = torch.max(w.abs(), dim=-1, keepdim=True)
indices = indices.unsqueeze(-1)
stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,))) # (..., freq, channel, 1)
return stv
def _get_steering_vector_power(
self, psd_s: torch.Tensor, psd_n: torch.Tensor, reference_vector: torch.Tensor
) -> torch.Tensor:
r"""Estimate the steering vector by the power method.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
psd_n (torch.Tensor): covariance matrix of noise
Tensor of dimension `(..., freq, channel, channel)`
reference_vector (torch.Tensor): one-hot reference channel matrix
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
stv = stv.unsqueeze(-1)
stv = torch.matmul(phi, stv)
stv = torch.matmul(psd_s, stv)
return stv
def _apply_beamforming_vector(self, specgram: torch.Tensor, beamform_vector: torch.Tensor) -> torch.Tensor:
r"""Apply the beamforming weight to the noisy STFT
Args:
specgram (torch.tensor): multi-channel noisy STFT
Tensor of dimension `(..., channel, freq, time)`
beamform_vector (torch.Tensor): beamforming weight matrix
Tensor of dimension `(..., freq, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, time)`
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
return specgram_enhanced
def _tik_reg(self, mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: 1e-8)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)
Returns:
torch.Tensor: regularized matrix (..., channel, channel)
"""
# Add eps
C = mat.size(-1)
eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
with torch.no_grad():
epsilon = _get_mat_trace(mat).real[..., None, None] * reg
# in case that correlation_matrix is all-zero
epsilon = epsilon + eps
mat = mat + epsilon * eye[..., :, :]
return mat
def forward( def forward(
self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
) -> torch.Tensor: ) -> torch.Tensor:
"""Perform MVDR beamforming. """Perform MVDR beamforming.
Args: Args:
specgram (torch.Tensor): the multi-channel STF of the noisy speech. specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)` Tensor with dimensions `(..., channel, freq, time)`
mask_s (torch.Tensor): Time-Frequency mask of target speech. mask_s (torch.Tensor): Time-Frequency mask of target speech.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or or dimension `(..., channel, freq, time)` if multi_mask is ``True`` or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise. mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or or dimension `(..., channel, freq, time)` if multi_mask is ``True`` or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: None) (Default: None)
Returns: Returns:
torch.Tensor: The single-channel STFT of the enhanced speech. torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
Tensor of dimension `(..., freq, time)`
""" """
dtype = specgram.dtype dtype = specgram.dtype
if specgram.ndim < 3: if specgram.ndim < 3:
...@@ -422,17 +319,6 @@ class MVDR(torch.nn.Module): ...@@ -422,17 +319,6 @@ class MVDR(torch.nn.Module):
warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.") warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
mask_n = 1 - mask_s mask_n = 1 - mask_s
shape = specgram.size()
# pack batch
specgram = specgram.reshape(-1, shape[-3], shape[-2], shape[-1])
if self.multi_mask:
mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1])
mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1])
else:
mask_s = mask_s.reshape(-1, shape[-2], shape[-1])
mask_n = mask_n.reshape(-1, shape[-2], shape[-1])
psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel) psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel)
psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel) psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel)
...@@ -444,12 +330,9 @@ class MVDR(torch.nn.Module): ...@@ -444,12 +330,9 @@ class MVDR(torch.nn.Module):
psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
) )
else: else:
w_mvdr = self._get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps) w_mvdr = _get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr)
# unpack batch specgram_enhanced = F.apply_beamforming(w_mvdr, specgram)
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
return specgram_enhanced.to(dtype) return specgram_enhanced.to(dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment