Commit f5036c71 authored by Zhaoheng Ni, committed by Facebook GitHub Bot

Refactor MVDR module (#2383)

Summary:
- Use the `apply_beamforming`, `rtf_evd`, `rtf_power`, `mvdr_weights_souden`, and `mvdr_weights_rtf` functions under `torchaudio.functional` to replace the corresponding class methods (a usage sketch follows this summary).
- Refactor docstrings in `PSD` and `MVDR`.
- Move `_get_mvdr_vector` out of the `MVDR` class, since it does not reference any instance attributes or methods.
- Since MVDR uses einsum for matrix operations, packing and unpacking batches is unnecessary; batch consistency is verified by the [batch_consistency_test](https://github.com/pytorch/audio/blob/main/test/torchaudio_unittest/transforms/batch_consistency_test.py#L202). The packing/unpacking code has been removed.
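
For illustration, here is a minimal sketch of the functional beamforming path the transform now delegates to, using the `ref_channel` (Souden) solution. The calls mirror the ones in the diff below; the shapes follow the docstrings, and the random spectrogram and masks are placeholders rather than part of this change.

```python
import torch
import torchaudio.functional as F

channel, freq, time = 4, 257, 100
specgram = torch.randn(channel, freq, time, dtype=torch.cdouble)  # multi-channel complex STFT
mask_s = torch.rand(freq, time)  # Time-Frequency mask of target speech
mask_n = 1 - mask_s              # Time-Frequency mask of noise

# PSD matrices of speech and noise: (..., freq, channel, channel)
psd_s = F.psd(specgram, mask_s)
psd_n = F.psd(specgram, mask_n)

# One-hot reference channel vector, as constructed inside the transform
reference_vector = torch.zeros(channel, dtype=psd_s.dtype)
reference_vector[0] = 1

# Compute the MVDR weights and apply them to the noisy STFT
w_mvdr = F.mvdr_weights_souden(psd_s, psd_n, reference_vector)  # (..., freq, channel)
enhanced = F.apply_beamforming(w_mvdr, specgram)                # (..., freq, time)
```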

Pull Request resolved: https://github.com/pytorch/audio/pull/2383

Reviewed By: carolineechen, mthrok

Differential Revision: D36338373

Pulled By: nateanl

fbshipit-source-id: a48a6ae2825657e5967a19656245596cdf037c5f
parent 09639680
......@@ -1748,17 +1748,18 @@ def psd(
.. properties:: Autograd TorchScript
Args:
specgram (Tensor): Multi-channel complex-valued spectrum.
Tensor of dimension `(..., channel, freq, time)`
mask (Tensor or None, optional): Real-valued time-frequency mask
for normalization. Tensor of dimension `(..., freq, time)`
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``)
normalize (bool, optional): whether to normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): a value added to the denominator in mask normalization. (Default: ``1e-10``)
normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
Returns:
Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor of dimension `(..., freq, channel, channel)`
torch.Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor with dimensions `(..., freq, channel, channel)`
"""
specgram = specgram.transpose(-3, -2) # shape (freq, channel, time)
# outer product:
......@@ -1780,14 +1781,14 @@ def _compute_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> t
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
input (torch.Tensor): Tensor with dimensions `(..., channel, channel)`.
dim1 (int, optional): The first dimension of the diagonal matrix.
(Default: ``-1``)
dim2 (int, optional): The second dimension of the diagonal matrix.
(Default: ``-2``)
Returns:
Tensor: trace of the input Tensor
Tensor: The trace of the input Tensor.
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
......@@ -1799,12 +1800,12 @@ def _tik_reg(mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.T
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: 1e-8)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: ``1e-8``)
mat (torch.Tensor): Input matrix with dimensions `(..., channel, channel)`.
reg (float, optional): Regularization factor. (Default: ``1e-7``)
eps (float, optional): Value added to avoid the correlation matrix being all-zero. (Default: ``1e-8``)
Returns:
Tensor: regularized matrix (..., channel, channel)
Tensor: Regularized matrix with dimensions `(..., channel, channel)`.
"""
# Add eps
C = mat.size(-1)
......
......@@ -11,23 +11,45 @@ from torchaudio import functional as F
__all__ = []
def _get_mat_trace(input: torch.Tensor, dim1: int = -1, dim2: int = -2) -> torch.Tensor:
r"""Compute the trace of a Tensor along ``dim1`` and ``dim2`` dimensions.
def _get_mvdr_vector(
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> torch.Tensor:
r"""Compute the MVDR beamforming weights with ``solution`` argument.
Args:
input (torch.Tensor): Tensor of dimension `(..., channel, channel)`
dim1 (int, optional): the first dimension of the diagonal matrix
(Default: -1)
dim2 (int, optional): the second dimension of the diagonal matrix
(Default: -2)
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
reference_vector (torch.Tensor): One-hot reference channel matrix.
solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
torch.Tensor: trace of the input Tensor
torch.Tensor: The MVDR beamforming weight matrix.
"""
assert input.ndim >= 2, "The dimension of the tensor must be at least 2."
assert input.shape[dim1] == input.shape[dim2], "The size of ``dim1`` and ``dim2`` must be the same."
input = torch.diagonal(input, 0, dim1=dim1, dim2=dim2)
return input.sum(dim=-1)
if solution == "ref_channel":
beamform_vector = F.mvdr_weights_souden(psd_s, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
else:
if solution == "stv_evd":
stv = F.rtf_evd(psd_s)
else:
stv = F.rtf_power(psd_s, psd_n, reference_vector, diagonal_loading=diagonal_loading, diag_eps=diag_eps)
beamform_vector = F.mvdr_weights_rtf(stv, psd_n, reference_vector, diagonal_loading, diag_eps, eps)
return beamform_vector
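
As an illustrative aside (not part of the diff), the RTF-based branch above can be exercised directly with the same functional calls; a hedged sketch with placeholder Hermitian PSD matrices:

```python
import torch
import torchaudio.functional as F

channel, freq = 4, 257
# Placeholder Hermitian PSD matrices; in practice they come from F.psd on masked spectra
a = torch.randn(freq, channel, channel, dtype=torch.cdouble)
b = torch.randn(freq, channel, channel, dtype=torch.cdouble)
psd_s = a @ a.conj().transpose(-1, -2)
psd_n = b @ b.conj().transpose(-1, -2)

reference_vector = torch.zeros(channel, dtype=psd_s.dtype)
reference_vector[0] = 1

# "stv_evd": steering vector from the dominant eigenvector of psd_s
stv = F.rtf_evd(psd_s)                                     # (..., freq, channel)
# "stv_power" alternative: stv = F.rtf_power(psd_s, psd_n, reference_vector)
w_mvdr = F.mvdr_weights_rtf(stv, psd_n, reference_vector)  # (..., freq, channel)
```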
class PSD(torch.nn.Module):
......@@ -38,9 +60,9 @@ class PSD(torch.nn.Module):
.. properties:: Autograd TorchScript
Args:
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``)
normalize (bool, optional): whether normalize the mask along the time dimension.
eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-15)
multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
normalize (bool, optional): If ``True``, normalize the mask along the time dimension. (Default: ``True``)
eps (float, optional): Value to add to the denominator in mask normalization. (Default: ``1e-15``)
"""
def __init__(self, multi_mask: bool = False, normalize: bool = True, eps: float = 1e-15):
......@@ -52,32 +74,23 @@ class PSD(torch.nn.Module):
def forward(self, specgram: torch.Tensor, mask: Optional[torch.Tensor] = None):
"""
Args:
specgram (torch.Tensor): multi-channel complex-valued STFT matrix.
Tensor of dimension `(..., channel, freq, time)`
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`.
mask (torch.Tensor or None, optional): Time-Frequency mask for normalization.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False`` or
of dimension `(..., channel, freq, time)` if multi_mask is ``True``
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False`` or
with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: ``None``)
Returns:
Tensor: PSD matrix of the input STFT matrix.
Tensor of dimension `(..., freq, channel, channel)`
torch.Tensor: The complex-valued PSD matrix of the input spectrum.
Tensor with dimensions `(..., freq, channel, channel)`
"""
# outer product:
# (..., ch_1, freq, time) x (..., ch_2, freq, time) -> (..., time, ch_1, ch_2)
psd = torch.einsum("...cft,...eft->...ftce", [specgram, specgram.conj()])
if mask is not None:
if self.multi_mask:
# Averaging mask along channel dimension
mask = mask.mean(dim=-3) # (..., freq, time)
psd = F.psd(specgram, mask, self.normalize, self.eps)
# Normalized mask along time dimension:
if self.normalize:
mask = mask / (mask.sum(dim=-1, keepdim=True) + self.eps)
psd = psd * mask.unsqueeze(-1).unsqueeze(-1)
psd = psd.sum(dim=-3)
return psd
......@@ -128,16 +141,16 @@ class MVDR(torch.nn.Module):
PSD matrices of speech and noise, respectively.
Args:
ref_channel (int, optional): the reference channel for beamforming. (Default: ``0``)
solution (str, optional): the solution to get MVDR weight.
ref_channel (int, optional): Reference channel for beamforming. (Default: ``0``)
solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
multi_mask (bool, optional): whether to use multi-channel Time-Frequency masks. (Default: ``False``)
diag_loading (bool, optional): whether apply diagonal loading on the psd matrix of noise.
(Default: ``True``)
diag_eps (float, optional): the coefficient multipied to the identity matrix for diagonal loading.
(Default: 1e-7)
online (bool, optional): whether to update the mvdr vector based on the previous psd matrices.
(Default: ``False``)
multi_mask (bool, optional): If ``True``, only accepts multi-channel Time-Frequency masks. (Default: ``False``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to the covariance matrix
of the noise. (Default: ``True``)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
online (bool, optional): If ``True``, updates the MVDR beamforming weights based on
the previous covariance matrices. (Default: ``False``)
Note:
To improve the numerical stability, the input spectrogram will be converted to double precision
......@@ -196,21 +209,28 @@ class MVDR(torch.nn.Module):
r"""Recursively update the MVDR beamforming vector.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
mask_s (torch.Tensor): T-F mask of target speech
mask_n (torch.Tensor): T-F mask of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
solution (str, optional): the solution to estimate the beamforming weight
(Default: ``ref_channel``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
mask_s (torch.Tensor): Time-Frequency mask of the target speech.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
mask_n (torch.Tensor or None, optional): Time-Frequency mask of the noise.
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
reference_vector (torch.Tensor): One-hot reference channel matrix.
solution (str, optional): Solution to compute the MVDR beamforming weights.
Options: [``ref_channel``, ``stv_evd``, ``stv_power``]. (Default: ``ref_channel``)
diagonal_loading (bool, optional): If ``True``, enables applying diagonal loading to ``psd_n``.
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
eps (float, optional): a value added to the denominator in mask normalization. (Default: 1e-8)
diag_eps (float, optional): The coefficient multiplied to the identity matrix for diagonal loading.
It is only effective when ``diagonal_loading`` is set to ``True``. (Default: ``1e-7``)
eps (float, optional): Value to add to the denominator in the beamforming weight formula.
(Default: ``1e-8``)
Returns:
Tensor: the mvdr beamforming weight matrix
torch.Tensor: The MVDR beamforming weight matrix.
"""
if self.multi_mask:
# Averaging mask along channel dimension
......@@ -221,7 +241,7 @@ class MVDR(torch.nn.Module):
self.psd_n = psd_n
self.mask_sum_s = mask_s.sum(dim=-1)
self.mask_sum_n = mask_n.sum(dim=-1)
return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
else:
psd_s = self._get_updated_psd_speech(psd_s, mask_s)
psd_n = self._get_updated_psd_noise(psd_n, mask_n)
......@@ -229,17 +249,19 @@ class MVDR(torch.nn.Module):
self.psd_n = psd_n
self.mask_sum_s = self.mask_sum_s + mask_s.sum(dim=-1)
self.mask_sum_n = self.mask_sum_n + mask_n.sum(dim=-1)
return self._get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
return _get_mvdr_vector(psd_s, psd_n, reference_vector, solution, diagonal_loading, diag_eps, eps)
def _get_updated_psd_speech(self, psd_s: torch.Tensor, mask_s: torch.Tensor) -> torch.Tensor:
r"""Update psd of speech recursively.
Args:
psd_s (torch.Tensor): psd matrix of target speech
mask_s (torch.Tensor): T-F mask of target speech
psd_s (torch.Tensor): The complex-valued power spectral density (PSD) matrix of target speech.
Tensor with dimensions `(..., freq, channel, channel)`.
mask_s (torch.Tensor): Time-Frequency mask of the target speech.
Tensor with dimensions `(..., freq, time)`.
Returns:
torch.Tensor: the updated psd of speech
torch.Tensor: The updated PSD matrix of target speech.
"""
numerator = self.mask_sum_s / (self.mask_sum_s + mask_s.sum(dim=-1))
denominator = 1 / (self.mask_sum_s + mask_s.sum(dim=-1))
......@@ -250,162 +272,37 @@ class MVDR(torch.nn.Module):
r"""Update psd of noise recursively.
Args:
psd_n (torch.Tensor): psd matrix of target noise
mask_n (torch.Tensor): T-F mask of target noise
psd_n (torch.Tensor): The complex-valued power spectral density (PSD) matrix of noise.
Tensor with dimensions `(..., freq, channel, channel)`.
mask_n (torch.Tensor): Time-Frequency mask of the noise.
Tensor with dimensions `(..., freq, time)`.
Returns:
torch.Tensor: the updated psd of noise
torch.Tensor: The updated PSD matrix of noise.
"""
numerator = self.mask_sum_n / (self.mask_sum_n + mask_n.sum(dim=-1))
denominator = 1 / (self.mask_sum_n + mask_n.sum(dim=-1))
psd_n = self.psd_n * numerator[..., None, None] + psd_n * denominator[..., None, None]
return psd_n
def _get_mvdr_vector(
self,
psd_s: torch.Tensor,
psd_n: torch.Tensor,
reference_vector: torch.Tensor,
solution: str = "ref_channel",
diagonal_loading: bool = True,
diag_eps: float = 1e-7,
eps: float = 1e-8,
) -> torch.Tensor:
r"""Compute beamforming vector by the reference channel selection method.
Args:
psd_s (torch.Tensor): psd matrix of target speech
psd_n (torch.Tensor): psd matrix of noise
reference_vector (torch.Tensor): one-hot reference channel matrix
solution (str, optional): the solution to estimate the beamforming weight
(Default: ``ref_channel``)
diagonal_loading (bool, optional): whether to apply diagonal loading to psd_n
(Default: ``True``)
diag_eps (float, optional): The coefficient multipied to the identity matrix for diagonal loading
(Default: 1e-7)
eps (float, optional): a value added to the denominator in mask normalization. Default: 1e-8
Returns:
torch.Tensor: the mvdr beamforming weight matrix
"""
if diagonal_loading:
psd_n = self._tik_reg(psd_n, reg=diag_eps, eps=eps)
if solution == "ref_channel":
numerator = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
# ws: (..., C, C) / (...,) -> (..., C, C)
ws = numerator / (_get_mat_trace(numerator)[..., None, None] + eps)
# h: (..., F, C_1, C_2) x (..., C_2) -> (..., F, C_1)
beamform_vector = torch.einsum("...fec,...c->...fe", [ws, reference_vector])
else:
if solution == "stv_evd":
stv = self._get_steering_vector_evd(psd_s)
else:
stv = self._get_steering_vector_power(psd_s, psd_n, reference_vector)
# numerator = psd_n.inv() @ stv
numerator = torch.linalg.solve(psd_n, stv).squeeze(-1) # (..., freq, channel)
# denominator = stv^H @ psd_n.inv() @ stv
denominator = torch.einsum("...d,...d->...", [stv.conj().squeeze(-1), numerator])
# normalzie the numerator
scale = stv.squeeze(-1)[..., self.ref_channel, None].conj()
beamform_vector = numerator / (denominator.real.unsqueeze(-1) + eps) * scale
return beamform_vector
def _get_steering_vector_evd(self, psd_s: torch.Tensor) -> torch.Tensor:
r"""Estimate the steering vector by eigenvalue decomposition.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
w, v = torch.linalg.eig(psd_s) # (..., freq, channel, channel)
_, indices = torch.max(w.abs(), dim=-1, keepdim=True)
indices = indices.unsqueeze(-1)
stv = v.gather(-1, indices.expand(psd_s.shape[:-1] + (1,))) # (..., freq, channel, 1)
return stv
def _get_steering_vector_power(
self, psd_s: torch.Tensor, psd_n: torch.Tensor, reference_vector: torch.Tensor
) -> torch.Tensor:
r"""Estimate the steering vector by the power method.
Args:
psd_s (torch.tensor): covariance matrix of speech
Tensor of dimension `(..., freq, channel, channel)`
psd_n (torch.Tensor): covariance matrix of noise
Tensor of dimension `(..., freq, channel, channel)`
reference_vector (torch.Tensor): one-hot reference channel matrix
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, channel, 1)`
"""
phi = torch.linalg.solve(psd_n, psd_s) # psd_n.inv() @ psd_s
stv = torch.einsum("...fec,...c->...fe", [phi, reference_vector])
stv = stv.unsqueeze(-1)
stv = torch.matmul(phi, stv)
stv = torch.matmul(psd_s, stv)
return stv
def _apply_beamforming_vector(self, specgram: torch.Tensor, beamform_vector: torch.Tensor) -> torch.Tensor:
r"""Apply the beamforming weight to the noisy STFT
Args:
specgram (torch.tensor): multi-channel noisy STFT
Tensor of dimension `(..., channel, freq, time)`
beamform_vector (torch.Tensor): beamforming weight matrix
Tensor of dimension `(..., freq, channel)`
Returns:
torch.Tensor: the enhanced STFT
Tensor of dimension `(..., freq, time)`
"""
# (..., channel) x (..., channel, freq, time) -> (..., freq, time)
specgram_enhanced = torch.einsum("...fc,...cft->...ft", [beamform_vector.conj(), specgram])
return specgram_enhanced
def _tik_reg(self, mat: torch.Tensor, reg: float = 1e-7, eps: float = 1e-8) -> torch.Tensor:
"""Perform Tikhonov regularization (only modifying real part).
Args:
mat (torch.Tensor): input matrix (..., channel, channel)
reg (float, optional): regularization factor (Default: 1e-8)
eps (float, optional): a value to avoid the correlation matrix is all-zero (Default: 1e-8)
Returns:
torch.Tensor: regularized matrix (..., channel, channel)
"""
# Add eps
C = mat.size(-1)
eye = torch.eye(C, dtype=mat.dtype, device=mat.device)
with torch.no_grad():
epsilon = _get_mat_trace(mat).real[..., None, None] * reg
# in case that correlation_matrix is all-zero
epsilon = epsilon + eps
mat = mat + epsilon * eye[..., :, :]
return mat
def forward(
self, specgram: torch.Tensor, mask_s: torch.Tensor, mask_n: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Perform MVDR beamforming.
Args:
specgram (torch.Tensor): the multi-channel STF of the noisy speech.
Tensor of dimension `(..., channel, freq, time)`
specgram (torch.Tensor): Multi-channel complex-valued spectrum.
Tensor with dimensions `(..., channel, freq, time)`
mask_s (torch.Tensor): Time-Frequency mask of target speech.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
mask_n (torch.Tensor or None, optional): Time-Frequency mask of noise.
Tensor of dimension `(..., freq, time)` if multi_mask is ``False``
or or dimension `(..., channel, freq, time)` if multi_mask is ``True``
Tensor with dimensions `(..., freq, time)` if multi_mask is ``False``
or with dimensions `(..., channel, freq, time)` if multi_mask is ``True``.
(Default: None)
Returns:
torch.Tensor: The single-channel STFT of the enhanced speech.
Tensor of dimension `(..., freq, time)`
torch.Tensor: Single-channel complex-valued enhanced spectrum with dimensions `(..., freq, time)`.
"""
dtype = specgram.dtype
if specgram.ndim < 3:
......@@ -422,17 +319,6 @@ class MVDR(torch.nn.Module):
warnings.warn("``mask_n`` is not provided, use ``1 - mask_s`` as ``mask_n``.")
mask_n = 1 - mask_s
shape = specgram.size()
# pack batch
specgram = specgram.reshape(-1, shape[-3], shape[-2], shape[-1])
if self.multi_mask:
mask_s = mask_s.reshape(-1, shape[-3], shape[-2], shape[-1])
mask_n = mask_n.reshape(-1, shape[-3], shape[-2], shape[-1])
else:
mask_s = mask_s.reshape(-1, shape[-2], shape[-1])
mask_n = mask_n.reshape(-1, shape[-2], shape[-1])
psd_s = self.psd(specgram, mask_s) # (..., freq, time, channel, channel)
psd_n = self.psd(specgram, mask_n) # (..., freq, time, channel, channel)
......@@ -444,12 +330,9 @@ class MVDR(torch.nn.Module):
psd_s, psd_n, mask_s, mask_n, u, self.solution, self.diag_loading, self.diag_eps
)
else:
w_mvdr = self._get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
specgram_enhanced = self._apply_beamforming_vector(specgram, w_mvdr)
w_mvdr = _get_mvdr_vector(psd_s, psd_n, u, self.solution, self.diag_loading, self.diag_eps)
# unpack batch
specgram_enhanced = specgram_enhanced.reshape(shape[:-3] + shape[-2:])
specgram_enhanced = F.apply_beamforming(w_mvdr, specgram)
return specgram_enhanced.to(dtype)
......
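
End-to-end, the refactored transform is used the same way as before. A minimal usage sketch, assuming the constructor arguments referenced in this diff (`ref_channel`, `solution`, `multi_mask`) and defaults for the rest:

```python
import torch
from torchaudio.transforms import MVDR

mvdr = MVDR(ref_channel=0, solution="ref_channel", multi_mask=False)

specgram = torch.randn(4, 257, 100, dtype=torch.cdouble)  # (channel, freq, time), complex STFT
mask_s = torch.rand(257, 100)                              # (freq, time)

# mask_n is optional; when omitted, 1 - mask_s is used (with a warning)
enhanced = mvdr(specgram, mask_s)  # (freq, time), single-channel enhanced spectrum
print(enhanced.shape)
```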